diff --git a/mindspore/lite/test/run_test.sh b/mindspore/lite/test/run_test.sh
index 9e9d677d9b..f048dbb364 100755
--- a/mindspore/lite/test/run_test.sh
+++ b/mindspore/lite/test/run_test.sh
@@ -12,19 +12,19 @@
 cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./
 TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
 cp -fr $TEST_DATA_DIR/testPK ./data
 
-#./lite-test --gtest_filter="*MindDataTestTensorDE*"
-#./lite-test --gtest_filter="*MindDataTestEager*"
-#
-#./lite-test --gtest_filter="TestTfliteParser*"
-#
-#./lite-test --gtest_filter="*TestHebing*"
-#
-#./lite-test --gtest_filter=TestFcFp32*
-#./lite-test --gtest_filter=TestConv1x1Fp32*
-#./lite-test --gtest_filter=TestStrassenFp32*
-#./lite-test --gtest_filter=TestDeConvolutionFp32*
-#
-#./lite-test --gtest_filter=TestPadInt8.*
-#./lite-test --gtest_filter=TestDeconvInt8.*
+./lite-test --gtest_filter="*MindDataTestTensorDE*"
+./lite-test --gtest_filter="*MindDataTestEager*"
+
+./lite-test --gtest_filter="TestTfliteParser*"
+
+./lite-test --gtest_filter="*TestHebing*"
+
+./lite-test --gtest_filter=TestFcFp32*
+./lite-test --gtest_filter=TestConv1x1Fp32*
+./lite-test --gtest_filter=TestStrassenFp32*
+./lite-test --gtest_filter=TestDeConvolutionFp32*
+
+./lite-test --gtest_filter=TestPadInt8.*
+./lite-test --gtest_filter=TestDeconvInt8.*
 ./lite-test --gtest_filter="TestTfliteParser*"
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add.tflite
similarity index 100%
rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add1.tflite
rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add.tflite
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite
deleted file mode 100644
index c0c379a60c..0000000000
Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite and /dev/null differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite
deleted file mode 100644
index d36b81e327..0000000000
Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite and /dev/null differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite
new file mode 100644
index 0000000000..6330efc0b8
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite
new file mode 100644
index 0000000000..11d0299154
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite
new file mode 100644
index 0000000000..2cb33152c3
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite
new file mode 100644
index 0000000000..0fc30506d0
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite
new file mode 100644
index 0000000000..d820f874f9
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite
new file mode 100644
index 0000000000..9baa40effd
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div.tflite
similarity index 100%
rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div1.tflite
rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div.tflite
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite
deleted file mode 100644
index bc367d8f1e..0000000000
Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite and /dev/null differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite
deleted file mode 100644
index a8b7e5ec7b..0000000000
Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite and /dev/null differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/hardswish.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/hardswish.tflite
new file mode 100644
index 0000000000..88f3097e75
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/hardswish.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sigmoid.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logistic.tflite
similarity index 100%
rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sigmoid.tflite
rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logistic.tflite
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/avg_pooling.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mean_pooling.tflite
similarity index 100%
rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/avg_pooling.tflite
rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mean_pooling.tflite
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul.tflite
similarity index 100%
rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul1.tflite
rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul.tflite
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite
deleted file mode 100644
index 089de5cb2d..0000000000
Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite and /dev/null differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite
deleted file mode 100644
index a0b90cde82..0000000000
Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite and /dev/null differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite
new file mode 100644
index 0000000000..13f3cf78a1
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite
new file mode 100644
index 0000000000..79891fd475
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite
new file mode 100644
index 0000000000..2689d11685
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub.tflite
similarity index 100%
rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub1.tflite
rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub.tflite
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite
deleted file mode 100644
index dbca4e40a1..0000000000
Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite and /dev/null differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite
deleted file mode 100644
index c223cdbf90..0000000000
Binary files a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite and /dev/null differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite
new file mode 100644
index 0000000000..6e7e08695a
Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite differ
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc
index 24f2a34899..84ec400efc 100644
--- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc
+++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc
@@ -31,6 +31,12 @@ TEST_F(TestTfliteParserRelu, OpType) {
   ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
 }
 
+TEST_F(TestTfliteParserRelu, AttrValue) {
+  ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
+  auto val =
meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_RELU); +} + class TestTfliteParserRelu6 : public TestTfliteParser { public: TestTfliteParserRelu6() = default; @@ -43,6 +49,12 @@ TEST_F(TestTfliteParserRelu6, OpType) { ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; } +TEST_F(TestTfliteParserRelu6, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_RELU6); +} + class TestTfliteParserTanh : public TestTfliteParser { public: TestTfliteParserTanh() = default; @@ -55,7 +67,45 @@ TEST_F(TestTfliteParserTanh, OpType) { ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; } -// logistic +TEST_F(TestTfliteParserTanh, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_TANH); +} + +class TestTfliteParserLogistic : public TestTfliteParser { + public: + TestTfliteParserLogistic() = default; + void SetUp() override { meta_graph = LoadAndConvert("./logistic.tflite", ""); } +}; + +TEST_F(TestTfliteParserLogistic, OpType) { + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; +} +TEST_F(TestTfliteParserLogistic, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_SIGMOID); +} + +class TestTfliteParserHardSwish : public TestTfliteParser { + public: + TestTfliteParserHardSwish() = default; + void SetUp() override { meta_graph = LoadAndConvert("./hardswish.tflite", ""); } +}; + +TEST_F(TestTfliteParserHardSwish, OpType) { + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; +} +TEST_F(TestTfliteParserHardSwish, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_SIGMOID); +} class TestTfliteParserPrelu : public TestTfliteParser { public: @@ -73,12 +123,11 @@ TEST_F(TestTfliteParserPrelu, OpType) { } TEST_F(TestTfliteParserPrelu, AttrValue) { - std::vector slope(20, 0); - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope); + auto val = meta_graph->nodes.front()->primitive->value; + std::vector slope(20, 0); + ASSERT_EQ(val.AsPrelu()->slope, slope); + ASSERT_EQ(val.type, schema::PrimitiveType_Prelu); } class TestTfliteParserLeakyRelu : public TestTfliteParser { @@ -94,12 +143,10 @@ TEST_F(TestTfliteParserLeakyRelu, OpType) { } TEST_F(TestTfliteParserLeakyRelu, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - 
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - - auto val = meta_graph->nodes.front()->primitive->value.AsLeakyReLU(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->negativeSlope, 0.20000000298023224); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLeakyReLU(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value; + ASSERT_EQ(val.AsLeakyReLU()->negativeSlope, 0.20000000298023224); + ASSERT_EQ(val.type, schema::PrimitiveType_LeakyReLU); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc index 72a8237673..7480d19b6c 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc @@ -35,10 +35,8 @@ TEST_F(TestTfliteParserAddN, OpType) { } TEST_F(TestTfliteParserAddN, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4); + auto val = meta_graph->nodes.front()->primitive->value.AsAddN(); + ASSERT_EQ(val->N, 4); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc new file mode 100644 index 0000000000..465930039d --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserArgmax : public TestTfliteParser { + public: + TestTfliteParserArgmax() = default; + void SetUp() override { meta_graph = LoadAndConvert("./argmax.tflite", ""); } +}; + +TEST_F(TestTfliteParserArgmax, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMax) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserArgmax, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMax(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsArgMax(); + ASSERT_EQ(val->axis, 1); + ASSERT_EQ(val->topK, 1); + ASSERT_EQ(val->axisType, 1); + ASSERT_EQ(val->keepDims, false); + ASSERT_EQ(val->outMaxValue, false); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc index 0bb625ec87..03f9e1c1bf 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc @@ -25,15 +25,14 @@ class TestTfliteParserArgmin : public TestTfliteParser { }; TEST_F(TestTfliteParserArgmin, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type"; } TEST_F(TestTfliteParserArgmin, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMin(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsArgMin(); ASSERT_EQ(val->axis, 1); ASSERT_EQ(val->topK, 1); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc index 64056ecde4..fdf08db8cf 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc @@ -19,234 +19,57 @@ namespace mindspore { // doubleInputOp -class TestTfliteParserAdd1 : public TestTfliteParser { +class TestTfliteParserAdd : public TestTfliteParser { public: - TestTfliteParserAdd1() = default; - void SetUp() override { meta_graph = LoadAndConvert("./add1.tflite", ""); } + TestTfliteParserAdd() = default; + void SetUp() override { meta_graph = LoadAndConvert("./add.tflite", ""); } }; -TEST_F(TestTfliteParserAdd1, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserAdd1, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserAdd2 : public TestTfliteParser { - public: - 
TestTfliteParserAdd2() = default; - void SetUp() override { meta_graph = LoadAndConvert("./add2.tflite", ""); } -}; - -TEST_F(TestTfliteParserAdd2, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserAdd2, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserAdd3 : public TestTfliteParser { - public: - TestTfliteParserAdd3() = default; - void SetUp() override { meta_graph = LoadAndConvert("./add3.tflite", ""); } -}; - -TEST_F(TestTfliteParserAdd3, OpType) { +TEST_F(TestTfliteParserAdd, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; } -TEST_F(TestTfliteParserAdd3, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserSub1 : public TestTfliteParser { - public: - TestTfliteParserSub1() = default; - void SetUp() override { meta_graph = LoadAndConvert("./sub1.tflite", ""); } -}; - -TEST_F(TestTfliteParserSub1, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserSub1, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserSub2 : public TestTfliteParser { - public: - TestTfliteParserSub2() = default; - void SetUp() override { meta_graph = LoadAndConvert("./sub2.tflite", ""); } -}; - -TEST_F(TestTfliteParserSub2, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserSub2, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserSub3 : public TestTfliteParser { +class TestTfliteParserSub : public TestTfliteParser { public: - TestTfliteParserSub3() = default; - void SetUp() override { meta_graph = LoadAndConvert("./sub3.tflite", ""); } + TestTfliteParserSub() = default; + void SetUp() override { meta_graph = LoadAndConvert("./sub.tflite", ""); } }; -TEST_F(TestTfliteParserSub3, OpType) { +TEST_F(TestTfliteParserSub, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; } 
-TEST_F(TestTfliteParserSub3, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserMul1 : public TestTfliteParser { - public: - TestTfliteParserMul1() = default; - void SetUp() override { meta_graph = LoadAndConvert("./mul1.tflite", ""); } -}; - -TEST_F(TestTfliteParserMul1, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserMul1, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserMul2 : public TestTfliteParser { +class TestTfliteParserMul : public TestTfliteParser { public: - TestTfliteParserMul2() = default; - void SetUp() override { meta_graph = LoadAndConvert("./mul2.tflite", ""); } + TestTfliteParserMul() = default; + void SetUp() override { meta_graph = LoadAndConvert("./mul.tflite", ""); } }; -TEST_F(TestTfliteParserMul2, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserMul2, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserMul3 : public TestTfliteParser { - public: - TestTfliteParserMul3() = default; - void SetUp() override { meta_graph = LoadAndConvert("./mul3.tflite", ""); } -}; - -TEST_F(TestTfliteParserMul3, OpType) { +TEST_F(TestTfliteParserMul, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; } -TEST_F(TestTfliteParserMul3, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserDiv1 : public TestTfliteParser { - public: - TestTfliteParserDiv1() = default; - void SetUp() override { meta_graph = LoadAndConvert("./div1.tflite", ""); } -}; - -TEST_F(TestTfliteParserDiv1, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserDiv1, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserDiv2 : public TestTfliteParser { +class TestTfliteParserDiv : public TestTfliteParser { public: - TestTfliteParserDiv2() = default; - void SetUp() override { meta_graph 
= LoadAndConvert("./div2.tflite", ""); } + TestTfliteParserDiv() = default; + void SetUp() override { meta_graph = LoadAndConvert("./div.tflite", ""); } }; -TEST_F(TestTfliteParserDiv2, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserDiv2, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserDiv3 : public TestTfliteParser { - public: - TestTfliteParserDiv3() = default; - void SetUp() override { meta_graph = LoadAndConvert("./div3.tflite", ""); } -}; - -TEST_F(TestTfliteParserDiv3, OpType) { +TEST_F(TestTfliteParserDiv, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; } - -TEST_F(TestTfliteParserDiv3, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - class TestTfliteParserFloorDiv : public TestTfliteParser { public: TestTfliteParserFloorDiv() = default; @@ -254,6 +77,7 @@ class TestTfliteParserFloorDiv : public TestTfliteParser { }; TEST_F(TestTfliteParserFloorDiv, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type"; @@ -266,12 +90,26 @@ class TestTfliteParserFloorMod : public TestTfliteParser { }; TEST_F(TestTfliteParserFloorMod, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type"; } -// realDiv +class TestTfliteParserRealDiv : public TestTfliteParser { + public: + TestTfliteParserRealDiv() = default; + void SetUp() override { + meta_graph = LoadAndConvert("./realdiv.tflite"); + } +}; + +TEST_F(TestTfliteParserRealDiv, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; +} class TestTfliteParserSquaredDifference : public TestTfliteParser { public: @@ -296,17 +134,15 @@ class TestTfliteParserPow : public TestTfliteParser { }; TEST_F(TestTfliteParserPow, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type"; } TEST_F(TestTfliteParserPow, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPower(), nullptr); auto val = 
meta_graph->nodes.front()->primitive->value.AsPower(); - ASSERT_EQ(val->scale, 1.0); ASSERT_EQ(val->shift, 0.0); ASSERT_EQ(val->power, 0.0); @@ -477,6 +313,7 @@ class TestTfliteParserFloor : public TestTfliteParser { }; TEST_F(TestTfliteParserFloor, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type"; diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc index c78d8eac54..8091bc6598 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc @@ -32,14 +32,12 @@ TEST_F(TestTfliteParserBatchToSpaceNd, OpType) { } TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) { - const std::vector blockShape{2, 2}; - const std::vector crops{0, 0, 2, 0}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops); + auto val = meta_graph->nodes.front()->primitive->value.AsBatchToSpace(); + const std::vector blockShape = {2, 2}; + ASSERT_EQ(val->blockShape, blockShape); + const std::vector crops = {0, 0, 2, 0}; + ASSERT_EQ(val->crops, crops); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc index 28fe10828c..4dd4c41a75 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc @@ -35,12 +35,9 @@ TEST_F(TestTfliteParserCast, OpType) { } TEST_F(TestTfliteParserCast, AttrValue) { - // float32 --> int32 - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34); + auto val = meta_graph->nodes.front()->primitive->value.AsCast(); + ASSERT_EQ(val->srcT, 43); + ASSERT_EQ(val->dstT, 34); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc new file mode 100644 index 0000000000..2d5fd6fe80 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserConcat : public TestTfliteParser { + public: + TestTfliteParserConcat() = default; + void SetUp() override { meta_graph = LoadAndConvert("./concat.tflite", ""); } +}; + +TEST_F(TestTfliteParserConcat, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Concat) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserConcat, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsConcat(); + ASSERT_EQ(val->axis, 1); + ASSERT_EQ(val->n, 2); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc new file mode 100644 index 0000000000..dc2a763d49 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserConv : public TestTfliteParser { + public: + TestTfliteParserConv() = default; + void SetUp() override { meta_graph = LoadAndConvert("./conv.tflite", ""); } +}; + +TEST_F(TestTfliteParserConv, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserConv, AttrValue) { + ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); + auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); + ASSERT_EQ(val->format, schema::Format_NHWC); + ASSERT_EQ(val->group, 1); + ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); + ASSERT_EQ(val->hasBias, true); + ASSERT_EQ(val->channelIn, 1); + ASSERT_EQ(val->channelOut, 4); + ASSERT_EQ(val->kernelH, 3); + ASSERT_EQ(val->kernelW, 3); + ASSERT_EQ(val->strideH, 1); + ASSERT_EQ(val->strideW, 1); + ASSERT_EQ(val->dilateH, 1); + ASSERT_EQ(val->dilateW, 1); + ASSERT_EQ(val->padMode, schema::PadMode_SAME); + ASSERT_EQ(val->padUp, 1); + ASSERT_EQ(val->padDown, 1); + ASSERT_EQ(val->padLeft, 1); + ASSERT_EQ(val->padRight, 1); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc new file mode 100644 index 0000000000..9ff4704c7c --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserDeConv : public TestTfliteParser { + public: + TestTfliteParserDeConv() = default; + void SetUp() override { meta_graph = LoadAndConvert("./deconv.tflite", ""); } +}; + +TEST_F(TestTfliteParserDeConv, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserDeConv, AttrValue) { + ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr); + auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(); + ASSERT_EQ(val->format, schema::Format_NHWC); + ASSERT_EQ(val->group, 1); + ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); + ASSERT_EQ(val->hasBias, true); + + ASSERT_EQ(val->channelIn, 1); + ASSERT_EQ(val->channelOut, 4); + ASSERT_EQ(val->kernelH, 3); + ASSERT_EQ(val->kernelW, 3); + ASSERT_EQ(val->strideH, 1); + ASSERT_EQ(val->strideW, 1); + ASSERT_EQ(val->dilateH, 1); + ASSERT_EQ(val->dilateW, 1); + ASSERT_EQ(val->padMode, schema::PadMode_SAME); + ASSERT_EQ(val->padUp, 1); + ASSERT_EQ(val->padDown, 1); + ASSERT_EQ(val->padLeft, 1); + ASSERT_EQ(val->padRight, 1); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc index 5283c0a85d..08abf71836 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserDepthToSpace, OpType) { } TEST_F(TestTfliteParserDepthToSpace, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC); + auto val = meta_graph->nodes.front()->primitive->value.AsDepthToSpace(); + ASSERT_EQ(val->blockSize, 4); + ASSERT_EQ(val->format, schema::Format_NHWC); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc new file mode 100644 index 0000000000..6ba9c4d1e6 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserDepthwiseConv1 : public TestTfliteParser { + public: + TestTfliteParserDepthwiseConv1() = default; + void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv1.tflite", ""); } +}; + +TEST_F(TestTfliteParserDepthwiseConv1, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) { + ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); + auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); + ASSERT_EQ(val->format, schema::Format_NHWC); + ASSERT_EQ(val->group, 0); + ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); + ASSERT_EQ(val->hasBias, true); + ASSERT_EQ(val->channelIn, 1); + ASSERT_EQ(val->channelOut, 4); + ASSERT_EQ(val->kernelH, 3); + ASSERT_EQ(val->kernelW, 3); + ASSERT_EQ(val->strideH, 1); + ASSERT_EQ(val->strideW, 1); + ASSERT_EQ(val->dilateH, 1); + ASSERT_EQ(val->dilateW, 1); + ASSERT_EQ(val->padMode, schema::PadMode_SAME); + ASSERT_EQ(val->padUp, 1); + ASSERT_EQ(val->padDown, 1); + ASSERT_EQ(val->padLeft, 1); + ASSERT_EQ(val->padRight, 1); +} + +class TestTfliteParserDepthwiseConv2 : public TestTfliteParser { + public: + TestTfliteParserDepthwiseConv2() = default; + void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv2.tflite", ""); } +}; + +TEST_F(TestTfliteParserDepthwiseConv2, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) { + ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr); + auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(); + ASSERT_EQ(val->format, schema::Format_NHWC); + ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); + ASSERT_EQ(val->hasBias, true); + ASSERT_EQ(val->channelIn, 2); + ASSERT_EQ(val->channelMultiplier, 1); + ASSERT_EQ(val->kernelH, 3); + ASSERT_EQ(val->kernelW, 3); + ASSERT_EQ(val->strideH, 1); + ASSERT_EQ(val->strideW, 1); + ASSERT_EQ(val->dilateH, 1); + ASSERT_EQ(val->dilateW, 1); + ASSERT_EQ(val->padMode, schema::PadMode_SAME); + ASSERT_EQ(val->padUp, 1); + ASSERT_EQ(val->padDown, 1); + ASSERT_EQ(val->padLeft, 1); + ASSERT_EQ(val->padRight, 1); +} + +} // namespace mindspore diff --git 
a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc
index 359d54b6e5..4ce37d77a4 100644
--- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc
+++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc
@@ -25,17 +25,15 @@ class TestTfliteParserFill : public TestTfliteParser {
 };
 
 TEST_F(TestTfliteParserFill, OpType) {
+  ASSERT_NE(meta_graph, nullptr);
   ASSERT_GT(meta_graph->nodes.size(), 0);
   ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
   ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type";
 }
 
-TEST_F(TestTfliteParserFill, AttrValue) {
-  ASSERT_GT(meta_graph->nodes.size(), 0);
-  ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
-
+TEST_F(TestTfliteParserFill, AttrValue) {
+  ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsFill(), nullptr);
   auto val = meta_graph->nodes.front()->primitive->value.AsFill();
-
   std::vector dims = {9};
   ASSERT_EQ(val->dims, dims);
 }
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc
index d6badec6a7..5c4af9e7f9 100644
--- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc
+++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc
@@ -25,15 +25,14 @@ class TestTfliteParserGatherNd : public TestTfliteParser {
 };
 
 TEST_F(TestTfliteParserGatherNd, OpType) {
+  ASSERT_NE(meta_graph, nullptr);
   ASSERT_GT(meta_graph->nodes.size(), 0);
   ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
   ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type";
 }
 
 TEST_F(TestTfliteParserGatherNd, AttrValue) {
-  ASSERT_GT(meta_graph->nodes.size(), 0);
-  ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
-
+  ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr);
   auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd();
   ASSERT_EQ(val->batchDims, 0);
 }
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc
index 13420abada..071738a15a 100644
--- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc
+++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc
@@ -25,15 +25,14 @@ class TestTfliteParserGather : public TestTfliteParser {
 };
 
 TEST_F(TestTfliteParserGather, OpType) {
+  ASSERT_NE(meta_graph, nullptr);
   ASSERT_GT(meta_graph->nodes.size(), 0);
   ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
   ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type";
 }
 
 TEST_F(TestTfliteParserGather, AttrValue) {
-  ASSERT_GT(meta_graph->nodes.size(), 0);
-  ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
-
+  ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGather(), nullptr);
   auto val = meta_graph->nodes.front()->primitive->value.AsGather();
   ASSERT_EQ(val->axis, 0);
   ASSERT_EQ(val->batchDims, 0);
diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc
index
88a504a6ae..8db8351d4b 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc @@ -25,6 +25,7 @@ class TestTfliteParserLRN : public TestTfliteParser { }; TEST_F(TestTfliteParserLRN, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, @@ -32,9 +33,7 @@ TEST_F(TestTfliteParserLRN, OpType) { } TEST_F(TestTfliteParserLRN, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(); ASSERT_EQ(val->alpha, 1); ASSERT_EQ(val->beta, 0.5); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc index ed3d946a96..2c7a40b910 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc @@ -32,12 +32,9 @@ TEST_F(TestTfliteParserOneHot, OpType) { } TEST_F(TestTfliteParserOneHot, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr); - // in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size() - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2); + auto val = meta_graph->nodes.front()->primitive->value.AsOneHot(); + ASSERT_EQ(val->axis, 2); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc index ada4657759..1d33e1a8fc 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc @@ -25,17 +25,15 @@ class TestTfliteParserPad : public TestTfliteParser { }; TEST_F(TestTfliteParserPad, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type"; } TEST_F(TestTfliteParserPad, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPad(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsPad(); - std::vector paddings = {1, 1, 2, 2, 3, 3, 4, 4}; ASSERT_EQ(val->paddings, paddings); } diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc index d234a76a32..99d2500d59 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc @@ -35,12 +35,8 @@ TEST_F(TestTfliteParserMaxPooling, OpType) { } 
TEST_F(TestTfliteParserMaxPooling, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); - ASSERT_NE(val, nullptr); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING); ASSERT_EQ(val->global, false); @@ -72,12 +68,8 @@ TEST_F(TestTfliteParserAvgPooling, OpType) { } TEST_F(TestTfliteParserAvgPooling, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); - ASSERT_NE(val, nullptr); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING); ASSERT_EQ(val->global, false); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc index bf790bd842..86928b867b 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc @@ -32,13 +32,9 @@ TEST_F(TestTfliteParserReduceMax, OpType) { } TEST_F(TestTfliteParserReduceMax, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax); ASSERT_EQ(val->keepDims, false); std::vector axes = {2}; ASSERT_EQ(val->axes, axes); @@ -58,13 +54,9 @@ TEST_F(TestTfliteParserReduceMin, OpType) { } TEST_F(TestTfliteParserReduceMin, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin); ASSERT_EQ(val->keepDims, false); std::vector axes = {2}; ASSERT_EQ(val->axes, axes); @@ -84,13 +76,9 @@ TEST_F(TestTfliteParserReduceProd, OpType) { } TEST_F(TestTfliteParserReduceProd, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd); ASSERT_EQ(val->keepDims, false); std::vector axes = {2}; ASSERT_EQ(val->axes, axes); @@ -111,13 +99,9 @@ TEST_F(TestTfliteParserSum, OpType) { } TEST_F(TestTfliteParserSum, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - 
ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum); ASSERT_EQ(val->keepDims, false); std::vector axes = {2}; ASSERT_EQ(val->axes, axes); @@ -138,13 +122,9 @@ TEST_F(TestTfliteParserMean, OpType) { } TEST_F(TestTfliteParserMean, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean); ASSERT_EQ(val->keepDims, true); std::vector axes = {2, 3}; ASSERT_EQ(val->axes, axes); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc index a4eef7898b..3d29e9c192 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc @@ -35,12 +35,9 @@ TEST_F(TestTfliteParserReshape, OpType) { } TEST_F(TestTfliteParserReshape, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr); - + auto val = meta_graph->nodes.front()->primitive->value.AsReshape(); std::vector shape = {3, 5, 20}; - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReshape()->shape, shape); // int32 + ASSERT_EQ(val->shape, shape); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc index ab0c9a9a25..cfff0c88ae 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc @@ -26,17 +26,15 @@ class TestTfliteParserResizeNN : public TestTfliteParser { }; TEST_F(TestTfliteParserResizeNN, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; } TEST_F(TestTfliteParserResizeNN, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsResize(); - ASSERT_NE(val, nullptr); ASSERT_EQ(val->alignCorners, false); ASSERT_EQ(val->newHeight, 3); ASSERT_EQ(val->newWidth, 100); @@ -52,17 +50,15 @@ class TestTfliteParserResizeBilinear : public TestTfliteParser { }; TEST_F(TestTfliteParserResizeBilinear, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); 
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; } TEST_F(TestTfliteParserResizeBilinear, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsResize(); - ASSERT_NE(val, nullptr); ASSERT_EQ(val->alignCorners, false); ASSERT_EQ(val->newHeight, 75); ASSERT_EQ(val->newWidth, 4); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc index ea1ffff935..e4a03440ba 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc @@ -25,17 +25,15 @@ class TestTfliteParserReverse : public TestTfliteParser { }; TEST_F(TestTfliteParserReverse, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type"; } TEST_F(TestTfliteParserReverse, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverse(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReverse(); - std::vector axis = {3}; ASSERT_EQ(val->axis, axis); } diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc index a5f3e58d99..ba5b7c5220 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc @@ -35,13 +35,10 @@ TEST_F(TestTfliteParserReverseSequence, OpType) { } TEST_F(TestTfliteParserReverseSequence, AttrValue) { - std::vector seq_length{7, 2, 3, 5}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length); + auto val = meta_graph->nodes.front()->primitive->value.AsReverseSequence(); + ASSERT_EQ(val->seqAxis, 1); + std::vector seq_length = {7, 2, 3, 5}; + ASSERT_EQ(val->seqLengths, seq_length); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc new file mode 100644 index 0000000000..655f114eed --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSlice : public TestTfliteParser { + public: + TestTfliteParserSlice() = default; + + void SetUp() override { meta_graph = LoadAndConvert("./slice.tflite"); } +}; + +TEST_F(TestTfliteParserSlice, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Slice) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserSlice, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSlice(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsSlice(); + ASSERT_EQ(val->format, schema::Format_NHWC); + std::vector begin = {1, 0, 0}; + ASSERT_EQ(val->begin, begin); + std::vector size = {1, 1, 3}; + ASSERT_EQ(val->size, size); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc index c6fb258e6e..88488f946f 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSoftmax, OpType) { } TEST_F(TestTfliteParserSoftmax, AttrValue) { -ASSERT_NE(meta_graph, nullptr); -ASSERT_GT(meta_graph->nodes.size(), 0); -ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); -ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr); -ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSoftMax()->axis, -1); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsSoftMax(); + ASSERT_EQ(val->axis, -1); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc index dab7110c1a..ae80d1481a 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc @@ -35,13 +35,11 @@ TEST_F(TestTfliteParserSpaceToBatchND, OpType) { } TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) { - std::vector blockshape{2, 2}; - std::vector padding{0, 0, 2, 0}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape); - 
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding); + auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(); + std::vector blockshape = {2, 2}; + ASSERT_EQ(val->blockShape, blockshape); + std::vector padding = {0, 0, 2, 0}; + ASSERT_EQ(val->paddings, padding); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc index 785baff517..d1ed72bee2 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSpaceToDepth, OpType) { } TEST_F(TestTfliteParserSpaceToDepth, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC); + auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(); + ASSERT_EQ(val->blockSize, 2); + ASSERT_EQ(val->format, schema::Format_NHWC); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc index 93e7badacf..c6b7fc1c1d 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc @@ -35,16 +35,14 @@ TEST_F(TestTfliteParserSparseToDense, OpType) { } TEST_F(TestTfliteParserSparseToDense, AttrValue) { - std::vector outputShape{5, 5}; - std::vector sparseValue{1}; - std::vector defaultValue{0}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false); + auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense(); + std::vector outputShape = {5, 5}; + ASSERT_EQ(val->outputShape, outputShape); + std::vector sparseValue = {1}; + ASSERT_EQ(val->sparseValue, sparseValue); + std::vector defaultValue = {0}; + ASSERT_EQ(val->defaultValue, defaultValue); + ASSERT_EQ(val->validateIndices, false); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc index b7c9e38cb9..97cb01d999 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc @@ -33,14 +33,12 @@ 
TEST_F(TestTfliteParserSplit, OpType) { } TEST_F(TestTfliteParserSplit, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); - const std::vector sizeSplits{2, 2}; - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); + auto val = meta_graph->nodes.front()->primitive->value.AsSplit(); + ASSERT_EQ(val->splitDim, 2); + ASSERT_EQ(val->numberSplit, 2); + const std::vector sizeSplits = {2, 2}; + ASSERT_EQ(val->sizeSplits, sizeSplits); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc index 76eb564247..b0c6e78105 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc @@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplitV, OpType) { } TEST_F(TestTfliteParserSplitV, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); - const std::vector sizeSplits{1, 3}; - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); + auto val = meta_graph->nodes.front()->primitive->value.AsSplit(); + ASSERT_EQ(val->splitDim, 0); + ASSERT_EQ(val->numberSplit, 2); + const std::vector sizeSplits = {1, 3}; + ASSERT_EQ(val->sizeSplits, sizeSplits); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc new file mode 100644 index 0000000000..ff6d01841a --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserStack : public TestTfliteParser { + public: + TestTfliteParserStack() = default; + + void SetUp() override { meta_graph = LoadAndConvert("./stack.tflite"); } +}; + +TEST_F(TestTfliteParserStack, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Stack) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserStack, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStack(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsStack(); + ASSERT_EQ(val->axis, 1); + ASSERT_EQ(val->n, 2); + const std::vector isScale = {3, 2, 3}; + ASSERT_EQ(val->isScale, isScale); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc index 88177cd9aa..ef6fa94a7a 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc @@ -35,21 +35,17 @@ TEST_F(TestTfliteParserStridedSlice, OpType) { } TEST_F(TestTfliteParserStridedSlice, AttrValue) { - std::vector begin{1, -1, 0}; - std::vector end{2, -3, 3}; - std::vector stride{1, -1, 1}; - std::vector isscale{3, 2, 3}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale); + auto val = meta_graph->nodes.front()->primitive->value.AsStridedSlice(); + ASSERT_EQ(val->beginMask, 0); + ASSERT_EQ(val->endMask, 0); + std::vector begin = {1, -1, 0}; + ASSERT_EQ(val->begin, begin); + std::vector end = {2, -3, 3}; + ASSERT_EQ(val->end, end); + std::vector stride = {1, -1, 1}; + ASSERT_EQ(val->stride, stride); + std::vector isscale = {3, 2, 3}; + ASSERT_EQ(val->isScale, isscale); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc index e025bad6f9..fe4d930acd 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserTile, OpType) { } TEST_F(TestTfliteParserTile, AttrValue)
{ - std::vector multiply{2, 3, 4}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply); + auto val = meta_graph->nodes.front()->primitive->value.AsTile(); + std::vector multiply = {2, 3, 4}; + ASSERT_EQ(val->multiples, multiply); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc index 68e3629611..a2623e8390 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc @@ -35,13 +35,10 @@ TEST_F(TestTfliteParserTopKV2, OpType) { } TEST_F(TestTfliteParserTopKV2, AttrValue) { - // attr->sorted default is true - std::vector k{3}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true); + auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2(); + std::vector k = {3}; + ASSERT_EQ(val->k, k); + ASSERT_EQ(val->sorted, true); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc new file mode 100644 index 0000000000..f5891da3bf --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserTranspose : public TestTfliteParser { + public: + TestTfliteParserTranspose() = default; + + void SetUp() override { meta_graph = LoadAndConvert("./transpose.tflite"); } +}; + +TEST_F(TestTfliteParserTranspose, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserTranspose, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsTranspose(); + ASSERT_EQ(val->conjugate, false); + std::vector perm = {1, 0}; + ASSERT_EQ(val->perm, perm); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc index 5b49883b78..0273adbfe6 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc @@ -35,10 +35,8 @@ TEST_F(TestTfliteParserUnique, OpType) { } TEST_F(TestTfliteParserUnique, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32 + auto val = meta_graph->nodes.front()->primitive->value.AsUnique(); + ASSERT_EQ(val->outType, 34); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc index 9c8e59232e..44c020d2c6 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserUnstack, OpType) { } TEST_F(TestTfliteParserUnstack, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1); + auto val = meta_graph->nodes.front()->primitive->value.AsUnstack(); + ASSERT_EQ(val->num, 5); + ASSERT_EQ(val->axis, 1); } } // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index a87e960c8c..c863c502a8 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -353,7 +353,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); } else if (weightTensor->format
== schema::Format_KCHW) { status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_CHWK) { + } else if (weightTensor->format == schema::Format_CHWK) { // from tflite status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } else { MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; @@ -369,7 +369,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_CHWK) { // from tf + } else if (weightTensor->format == schema::Format_CHWK) { // from tflite status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } else { MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc index 9eb1f05095..f24264cd58 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc @@ -21,11 +21,16 @@ namespace lite { STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, std::vector *weightVec) { if (op == nullptr) { - // MS_LOGE("null pointer dereferencing."); + // MS_LOG(ERROR) << "null pointer dereferencing."; return RET_NULL_PTR; } - std::unique_ptr attr(new schema::ReshapeT()); - attr->format = schema::Format_NCHW; + std::unique_ptr attr(new schema::FullConnectionT()); + const caffe::FlattenParameter flattenParam = proto.flatten_param(); + + attr->axis = (int32_t)flattenParam.axis(); + attr->useAxis = true; + attr->hasBias = false; + attr->activationType = schema::ActivationType_NO_ACTIVATION; op->primitive = std::make_unique(); op->primitive->value.type = schema::PrimitiveType_Flatten; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc index 3ffb6ef901..50fb64f71d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc @@ -14,18 +14,21 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_activation_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_activation_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteActivationParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteActivationParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,13 +38,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr &t MS_LOG(ERROR) << "op->primitive is null"; return RET_NULL_PTR; } - std::unique_ptr attr(new schema::ActivationT()); std::vector node_name_str; Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); - if (std::strcmp(node_name, "Relu") == 0) { MS_LOG(DEBUG) << "parse TfliteReluParser"; attr->type = schema::ActivationType_RELU; @@ -54,29 +55,31 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr &t } else if (std::strcmp(node_name, "Logistic") == 0) { MS_LOG(DEBUG) << "parse TfliteLogisticParser"; attr->type = schema::ActivationType_SIGMOID; - } else if (std::strcmp(node_name, "LeakyRelu") == 0) { - const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions(); - if (option == nullptr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->type = schema::ActivationType_LEAKY_RELU; - attr->alpha = option->alpha; - } else { - MS_LOG(ERROR) << "wrong activation type"; - return RET_ERROR; + } else if (std::strcmp(node_name, "HardSwish") == 0) { + MS_LOG(DEBUG) << "parse TfliteHardSwishParser"; + attr->type = schema::ActivationType_SIGMOID; } + attr->alpha = 0.2f; op->primitive->value.type = schema::PrimitiveType_Activation; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } STATUS TflitePreluParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TflitePreluParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -86,23 +89,64 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr &tflite MS_LOG(ERROR) << "op->primitive is null"; return RET_NULL_PTR; } - - MS_LOG(DEBUG) << "paser TflitePreluParser"; std::unique_ptr attr(new schema::PreluT()); if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { MS_LOG(ERROR) << "get pRelu -> slope failed"; return RET_ERROR; } - op->primitive->value.type = schema::PrimitiveType_Prelu; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + 
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; +} + +STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteLeakyReluParser"; + + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr(new schema::LeakyReLUT()); + + const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + attr->negativeSlope = tflite_attr->alpha; + + op->primitive->value.type = schema::PrimitiveType_LeakyReLU; + op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); +TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser()); TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h index 0c4b350932..0223bcffac 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h @@ -14,13 +14,14 @@ * limitations under the License. 
*/ -#ifndef PREDICT_TFLITE_RELU_PARSER_H -#define PREDICT_TFLITE_RELU_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H -#include "tools/converter/parser/tflite/tflite_node_parser.h" -#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include #include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" namespace mindspore { namespace lite { @@ -29,11 +30,13 @@ class TfliteActivationParser : public TfliteNodeParser { public: TfliteActivationParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteReluParser : public TfliteActivationParser { @@ -56,9 +59,9 @@ class TfliteLogisticParser : public TfliteActivationParser { TfliteLogisticParser() : TfliteActivationParser() {} }; -class TfliteLeakyReluParser : public TfliteActivationParser { +class TfliteHardSwishParser : public TfliteActivationParser { public: - TfliteLeakyReluParser() : TfliteActivationParser() {} + TfliteHardSwishParser() : TfliteActivationParser() {} }; class TflitePreluParser : public TfliteNodeParser { @@ -68,12 +71,27 @@ class TflitePreluParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; +}; + +class TfliteLeakyReluParser : public TfliteNodeParser { + public: + TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_RELU_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc index cf518ceafb..b88666fca9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc @@ -18,14 +18,20 @@ #include "tools/converter/parser/tflite/tflite_addn_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteAddNParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteAddNParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + 
schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteAddNParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,13 +42,19 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteAddNParser"; std::unique_ptr attr(new schema::AddNT()); - attr->N = tfliteTensors.size() - 1; - + attr->N = tflite_tensors.size() - 1; op->primitive->value.type = schema::PrimitiveType_AddN; op->primitive->value.value = attr.release(); + + // set input + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h index bdc51bfc48..8bd1ef03ac 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_ADDN_PARSER_H -#define LITE_TFLITE_ADDN_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteAddNParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_ADDN_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc index b3fe114303..1ccf2ea465 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_argmax_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, +STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +40,6 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - 
MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; std::unique_ptr attr(new schema::ArgMaxT()); attr->outMaxValue = false; @@ -45,9 +47,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflit attr->keepDims = false; attr->axisType = 1; - auto axis_idx = tfliteOp->inputs[1]; - std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); - auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; + // get axis attr + auto axis_idx = tflite_op->inputs[1]; + std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){}); + auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; if (buf_data == nullptr) { MS_LOG(ERROR) << "the buf data is null"; return RET_NULL_PTR; @@ -61,6 +64,11 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_ArgMax; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h index 0665a6b028..f7dc10cfaf 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_ARGMAX_PARSER_H -#define PREDICT_TFLITE_ARGMAX_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteArgmaxParser : public TfliteNodeParser { public: TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ARGMAX_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc index 37a47c0ea3..1ce77d5eba 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc @@ -17,14 +17,19 @@ #include "tools/converter/parser/tflite/tflite_argmin_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteArgminParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - 
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteArgminParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,7 +40,6 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteArgminParser"; std::unique_ptr attr(new schema::ArgMinT()); attr->outMaxValue = false; @@ -43,9 +47,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflit attr->keepDims = false; attr->axisType = 1; - auto axis_idx = tfliteOp->inputs[1]; - std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); - auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; + // get axis attr + auto axis_idx = tflite_op->inputs[1]; + std::for_each(tflite_tensors[axis_idx]->shape.begin(), + tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){}); + auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; if (buf_data == nullptr) { MS_LOG(ERROR) << "the buf data is null"; return RET_NULL_PTR; @@ -59,6 +65,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_ArgMin; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h index a02d4fe5e2..4213fc3211 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#ifndef PREDICT_TFLITE_ARGMIN_PARSER_H -#define PREDICT_TFLITE_ARGMIN_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteArgminParser : public TfliteNodeParser { public: TfliteArgminParser() : TfliteNodeParser("Argmin") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ARGMIN_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index 99eeb7f860..8a7eb16236 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -18,14 +18,17 @@ #include #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,124 +40,72 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); - - if (std::strcmp(node_name, "Add") == 0 - || std::strcmp(node_name, "Sub") == 0 - || std::strcmp(node_name, "Mul") == 0 - || std::strcmp(node_name, "Div") == 0) { - auto x_index = tfliteOp->inputs[0]; - const auto &x_tensor = tfliteTensors[x_index]; - if (x_tensor == nullptr) { - MS_LOG(ERROR) << "the first input is null"; + if (std::strcmp(node_name, "Add") == 0) { + MS_LOG(DEBUG) << "parse TfliteAddParser"; + std::unique_ptr attr(new schema::AddT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - auto &x_data = tfliteModelBuffer.at(x_tensor->buffer); - if (x_data == nullptr) { - MS_LOG(ERROR) << "the data of the first input is null"; + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + op->primitive->value.type = schema::PrimitiveType_Add; + op->primitive->value.value = attr.release(); + } else if (std::strcmp(node_name, "Sub") == 0) { + MS_LOG(DEBUG) << "parse 
TfliteSubParser"; + std::unique_ptr attr(new schema::SubT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - if (!x_data->data.empty()) { - std::vector x_tensors{x_tensor.get()}; - if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse the first tensor failed"; - return RET_ERROR; - } - } - - auto y_index = tfliteOp->inputs[1]; - const auto &y_tensor = tfliteTensors[y_index]; - if (y_tensor == nullptr) { - MS_LOG(ERROR) << "the second input is null"; + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + op->primitive->value.type = schema::PrimitiveType_Sub; + op->primitive->value.value = attr.release(); + } else if (std::strcmp(node_name, "Mul") == 0) { + MS_LOG(DEBUG) << "parse TfliteMulParser"; + std::unique_ptr attr(new schema::MulT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); - if (y_data == nullptr) { - MS_LOG(ERROR) << "the data of the second input is null"; + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + op->primitive->value.type = schema::PrimitiveType_Mul; + op->primitive->value.value = attr.release(); + } else if (std::strcmp(node_name, "Div") == 0) { + MS_LOG(DEBUG) << "parse TfliteDivParser"; + std::unique_ptr attr(new schema::DivT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - if (!y_data->data.empty()) { - std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse the second tensor failed"; - return RET_ERROR; - } - } - - if (std::strcmp(node_name, "Add") == 0) { - MS_LOG(DEBUG) << "parse TfliteAddParser"; - std::unique_ptr attr(new schema::AddT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions(); - if (nullptr == tfliteAttr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); - op->primitive->value.type = schema::PrimitiveType_Add; - op->primitive->value.value = attr.release(); - return RET_OK; - } else if (std::strcmp(node_name, "Sub") == 0) { - MS_LOG(DEBUG) << "parse TfliteSubParser"; - std::unique_ptr attr(new schema::SubT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions(); - if (nullptr == tfliteAttr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); - op->primitive->value.type = schema::PrimitiveType_Sub; - op->primitive->value.value = attr.release(); - return RET_OK; - } else if (std::strcmp(node_name, "Mul") == 0) { - MS_LOG(DEBUG) << "parse TfliteMulParser"; - std::unique_ptr attr(new schema::MulT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions(); - if (nullptr == tfliteAttr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = 
GetActivationFunctionType(tfliteAttr->fused_activation_function); - op->primitive->value.type = schema::PrimitiveType_Mul; - op->primitive->value.value = attr.release(); - return RET_OK; - } else if (std::strcmp(node_name, "Div") == 0) { - MS_LOG(DEBUG) << "parse TfliteDivParser"; - std::unique_ptr attr(new schema::DivT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions(); - if (nullptr == tfliteAttr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); - op->primitive->value.type = schema::PrimitiveType_Div; - op->primitive->value.value = attr.release(); - return RET_OK; - } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + op->primitive->value.type = schema::PrimitiveType_Div; + op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "FloorDiv") == 0) { MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; std::unique_ptr attr(new schema::FloorDivT()); op->primitive->value.type = schema::PrimitiveType_FloorDiv; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "FloorMod") == 0) { MS_LOG(DEBUG) << "parse TfliteFloorModParser"; std::unique_ptr attr(new schema::FloorModT()); op->primitive->value.type = schema::PrimitiveType_FloorMod; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "RealDiv") == 0) { MS_LOG(DEBUG) << "parse TfliteRealDivParser"; std::unique_ptr attr(new schema::RealDivT()); - op->primitive->value.type = schema::PrimitiveType_RealDiv; + op->primitive->value.type = schema::PrimitiveType_Div; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "SquaredDifference") == 0) { MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; std::unique_ptr attr(new schema::SquaredDifferenceT()); op->primitive->value.type = schema::PrimitiveType_SquaredDifference; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Pow") == 0) { MS_LOG(DEBUG) << "parse TflitePowParser"; std::unique_ptr attr(new schema::PowerT()); @@ -163,31 +114,35 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr attr->shift = 0.0f; op->primitive->value.type = schema::PrimitiveType_Power; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Maximum") == 0) { MS_LOG(DEBUG) << "parse TfliteMaximumParser"; std::unique_ptr attr(new schema::MaximumT()); op->primitive->value.type = schema::PrimitiveType_Maximum; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Minimum") == 0) { MS_LOG(DEBUG) << "parse TfliteMinimumParser"; std::unique_ptr attr(new schema::MinimumT()); op->primitive->value.type = schema::PrimitiveType_Minimum; op->primitive->value.value = attr.release(); - return RET_OK; - } else { - MS_LOG(ERROR) << "wrong op type"; - return RET_ERROR; } + + // set input + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } -STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> 
&tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -199,85 +154,79 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "Abs") == 0) { MS_LOG(DEBUG) << "parse TfliteAbsParser"; std::unique_ptr attr(new schema::AbsT()); op->primitive->value.type = schema::PrimitiveType_Abs; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Exp") == 0) { MS_LOG(DEBUG) << "parse TfliteExpParser"; std::unique_ptr attr(new schema::ExpT()); op->primitive->value.type = schema::PrimitiveType_Exp; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Sqrt") == 0) { MS_LOG(DEBUG) << "parse TfliteSqrtParser"; std::unique_ptr attr(new schema::SqrtT()); op->primitive->value.type = schema::PrimitiveType_Sqrt; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Rsqrt") == 0) { MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; std::unique_ptr attr(new schema::RsqrtT()); op->primitive->value.type = schema::PrimitiveType_Rsqrt; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Square") == 0) { MS_LOG(DEBUG) << "parse TfliteSquareParser"; std::unique_ptr attr(new schema::SquareT()); op->primitive->value.type = schema::PrimitiveType_Square; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Sin") == 0) { MS_LOG(DEBUG) << "parse TfliteSinParser"; std::unique_ptr attr(new schema::SinT()); op->primitive->value.type = schema::PrimitiveType_Sin; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Cos") == 0) { MS_LOG(DEBUG) << "parse TfliteCosParser"; std::unique_ptr attr(new schema::CosT()); op->primitive->value.type = schema::PrimitiveType_Cos; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Log") == 0) { MS_LOG(DEBUG) << "parse TfliteLogParser"; std::unique_ptr attr(new schema::LogT()); op->primitive->value.type = schema::PrimitiveType_Log; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Round") == 0) { MS_LOG(DEBUG) << "parse TfliteRoundParser"; std::unique_ptr attr(new schema::RoundT()); op->primitive->value.type = schema::PrimitiveType_Round; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Ceil") == 0) { MS_LOG(DEBUG) << "parse TfliteCeilParser"; std::unique_ptr attr(new schema::CeilT()); op->primitive->value.type = schema::PrimitiveType_Ceil; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "flOOR") == 0) { MS_LOG(DEBUG) << "parse TfliteFloorParser"; std::unique_ptr attr(new schema::FloorT()); op->primitive->value.type = schema::PrimitiveType_Floor; op->primitive->value.value = attr.release(); - return RET_OK; - } else 
{ - MS_LOG(ERROR) << "wrong op type"; - return RET_ERROR; } + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; } -STATUS TfliteCompareOpParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteCompareOpParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -289,48 +238,47 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr &tf } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "Equal") == 0) { MS_LOG(DEBUG) << "parse TfliteEqualParser"; std::unique_ptr attr(new schema::EqualT()); op->primitive->value.type = schema::PrimitiveType_Equal; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "NotEqual") == 0) { MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; std::unique_ptr attr(new schema::NotEqualT()); op->primitive->value.type = schema::PrimitiveType_NotEqual; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Greater") == 0) { MS_LOG(DEBUG) << "parse TfliteGreaterParser"; std::unique_ptr attr(new schema::GreaterT()); op->primitive->value.type = schema::PrimitiveType_Greater; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "GreaterEqual") == 0) { MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; std::unique_ptr attr(new schema::GreaterEqualT()); op->primitive->value.type = schema::PrimitiveType_GreaterEqual; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Less") == 0) { MS_LOG(DEBUG) << "parse TfliteLessParser"; std::unique_ptr attr(new schema::LessT()); op->primitive->value.type = schema::PrimitiveType_Less; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "LessEqual") == 0) { MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; std::unique_ptr attr(new schema::LessEqualT()); op->primitive->value.type = schema::PrimitiveType_LessEqual; op->primitive->value.value = attr.release(); - return RET_OK; - } else { - MS_LOG(ERROR) << "wrong op type"; - return RET_ERROR; } + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; } TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h index 8df29fb87b..d79da7a58a 100644 --- 
a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_MATH_PARSER_H -#define PREDICT_TFLITE_MATH_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -29,11 +30,13 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { public: TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteAddParser : public TfliteDoubleInputOpParser { @@ -96,11 +99,13 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { public: TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteAbsParser : public TfliteSingleInputOpParser { @@ -163,11 +168,13 @@ class TfliteCompareOpParser : public TfliteNodeParser { public: TfliteCompareOpParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteEqualParser : public TfliteCompareOpParser { @@ -203,5 +210,5 @@ class TfliteLessEqualParser : public TfliteCompareOpParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_MATH_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc index f42338bcaf..b1b4d5fad6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc @@ -19,14 +19,17 @@ #include #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - 
const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -38,30 +41,32 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "BatchToSpace") == 0) { MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; - // in tflite - // blockShape should be a 1D tensor with dimension [spatial_dims_num] - // crops should be a 2D tensor with dimension [spatial_dims_num, 2] } std::unique_ptr attr(new schema::BatchToSpaceT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; return RET_ERROR; } - if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) { MS_LOG(ERROR) << "get batchToSpace -> crops failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_BatchToSpace; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h index def11ce9f9..8e28f3b4cf 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
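[Editor's sketch] The BatchToSpace hunk above is typical of the new bookkeeping: every parser now records its operands through AddOpInput/AddOpOutput instead of a TensorCache. The diff only shows the call sites, so the following is a minimal C++ sketch of what such helpers presumably do, using simplified stand-in types (Format, CNode) rather than the project's schema types.

// Simplified stand-ins; not the project's schema::CNodeT or schema::Format.
#include <cstdint>
#include <map>
#include <vector>

enum class Format { NHWC, KHWC };

struct CNode {
  std::vector<uint32_t> inputIndex;
  std::vector<uint32_t> outputIndex;
};

// Assumed behavior: map a tflite tensor index to a graph-wide tensor id,
// record its format, and append that id to the op's inputs. At the call
// sites in this patch, new_id is tensors_id->size() and total is
// tflite_tensors.size(); total is unused in this simplified sketch.
void AddOpInput(CNode *op, std::vector<int32_t> *tensors_id,
                std::vector<Format> *tensors_format,
                std::map<int, int> *tensors_id_map,
                int tflite_index, int new_id, int total, Format fmt) {
  (void)total;
  auto it = tensors_id_map->find(tflite_index);
  if (it != tensors_id_map->end()) {
    op->inputIndex.push_back(it->second);  // tensor already registered
    return;
  }
  tensors_id->push_back(tflite_index);
  tensors_format->push_back(fmt);
  (*tensors_id_map)[tflite_index] = new_id;
  op->inputIndex.push_back(new_id);
}

AddOpOutput would do the same for outputIndex, which is why every hunk in this patch ends with one AddOpInput call per operand followed by a single AddOpOutput call.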
*/ -#ifndef LITE_TFLITE_BATCH_TO_SPACE_PARSER_H -#define LITE_TFLITE_BATCH_TO_SPACE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,8 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { @@ -43,4 +46,4 @@ class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_BATCH_TO_SPACE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc index f73bcc9a87..199aff567b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc @@ -18,14 +18,19 @@ #include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,16 +41,20 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr & return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; std::unique_ptr attr(new schema::BroadcastToT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dst_shape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) { MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_BroadcastTo; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h index 0bbebd449b..25478346fc 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h +++ 
b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_BROADCAST_TO_PARSER_H -#define LITE_TFLITE_BROADCAST_TO_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteBroadcastToParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_BROADCAST_TO_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc index 78fe45fb25..ee5120ddec 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc @@ -14,18 +14,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "tools/converter/parser/tflite/tflite_cast_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteCastParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteCastParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteCastParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,25 +40,28 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteCastParser"; std::unique_ptr attr(new schema::CastT()); - const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; } - attr->srcT = dtype_map[in_tensor->type]; - - const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]]; + attr->srcT = GetTfliteDataType(in_tensor->type); + const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; } - attr->dstT = dtype_map[out_tensor->type]; + attr->dstT = GetTfliteDataType(out_tensor->type); op->primitive->value.type = schema::PrimitiveType_Cast; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), 
schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h index ae1dca284c..151808dbd5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_CAST_PARSER_ -#define LITE_TFLITE_CAST_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteCastParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_CAST_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc index a930233cee..86c4d98b89 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc @@ -17,14 +17,20 @@ #include "tools/converter/parser/tflite/tflite_concat_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteConcatParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteConcatParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteConcatParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,20 +41,25 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteConcatParser"; std::unique_ptr attr(new schema::ConcatT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions(); + const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions(); if (tfliteAttr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } attr->axis = tfliteAttr->axis; - - attr->n = tfliteOp->inputs.size(); + attr->n = tflite_op->inputs.size(); op->primitive->value.type = schema::PrimitiveType_Concat; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), 
schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h index d2a1acff77..eac2caf581 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_CONCAT_PARSER_H -#define PREDICT_TFLITE_CONCAT_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteConcatParser : public TfliteNodeParser { public: TfliteConcatParser() : TfliteNodeParser("Concat") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONCAT_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc index f075f10e85..8de26527b2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -17,14 +17,19 @@ #include "tools/converter/parser/tflite/tflite_conv_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteConvParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteConvParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteConvParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,60 +40,61 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteConvParser"; std::unique_ptr attr(new schema::Conv2DT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions(); - if (tfliteAttr == nullptr) { + const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions(); + if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } attr->group = 1; - attr->strideW = tfliteAttr->stride_w; - attr->strideH = tfliteAttr->stride_h; - attr->dilateH = tfliteAttr->dilation_h_factor; - attr->dilateW = tfliteAttr->dilation_w_factor; - attr->padMode = GetPadMode(tfliteAttr->padding); + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->dilateH = 
tflite_attr->dilation_h_factor; + attr->dilateW = tflite_attr->dilation_w_factor; + attr->padMode = GetPadMode(tflite_attr->padding); attr->format = schema::Format_NHWC; - attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); + attr->hasBias = true; // get the conv op weight tensor - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_tensors[weight_index]; if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; + MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } auto weight_shape = weight_tensor->shape; - attr->channelIn = weight_shape[KHWC_C]; - attr->channelOut = weight_shape[KHWC_K]; - attr->kernelW = weight_shape[KHWC_W]; - attr->kernelH = weight_shape[KHWC_H]; - - // get the conv op bias tensor - if (tfliteOp->inputs.size() == 3) { - attr->hasBias = true; - auto bias_index = tfliteOp->inputs[2]; - const auto &bias_tensor = tfliteTensors[bias_index]; - if (bias_tensor == nullptr) { - MS_LOG(ERROR) << "bias_tensor is null"; - return RET_NULL_PTR; - } - std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse bias failed"; - return RET_ERROR; - } - } + attr->channelIn = weight_shape[3]; + attr->channelOut = weight_shape[0]; + attr->kernelH = weight_shape[1]; + attr->kernelW = weight_shape[2]; // calculate pad params + auto data_index = tflite_op->inputs[0]; + const auto &data_tensor = tflite_tensors[data_index]; + std::vector params; + if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, + attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { + MS_LOG(ERROR) << "get padding params failed"; + return RET_ERROR; + } else { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); + } op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h index d2f523a0c3..abb5d889f7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
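[Editor's sketch] The Conv2D hunk above replaces the old weight/bias ParseTensor path with an explicit pad calculation via getPaddingParam. The project's helper is not shown in this diff; the sketch below is the standard TFLite SAME/VALID padding arithmetic it presumably implements, returning pads in the same {up, down, left, right} order that the hunk reads back from params.at(0)..params.at(3).

#include <algorithm>
#include <vector>

enum class PadMode { SAME, VALID };

// Returns {pad_up, pad_down, pad_left, pad_right} for an NHWC input of
// spatial size in_h x in_w. For SAME padding, the extra pixel (if any)
// goes to the bottom/right, matching the usual TFLite convention.
std::vector<int> PaddingParams(int in_h, int in_w, PadMode mode,
                               int stride_h, int stride_w,
                               int kernel_h, int kernel_w) {
  if (mode != PadMode::SAME) {
    return {0, 0, 0, 0};
  }
  auto pad_total = [](int in, int stride, int kernel) {
    int out = (in + stride - 1) / stride;                  // ceil(in / stride)
    return std::max((out - 1) * stride + kernel - in, 0);  // total padding
  };
  int ph = pad_total(in_h, stride_h, kernel_h);
  int pw = pad_total(in_w, stride_w, kernel_w);
  return {ph / 2, ph - ph / 2, pw / 2, pw - pw / 2};
}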
*/ -#ifndef PREDICT_TFLITE_CONV_PARSER_H -#define PREDICT_TFLITE_CONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteConvParser : public TfliteNodeParser { public: TfliteConvParser() : TfliteNodeParser("Conv2D") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h index 9269502330..d510d74a16 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h @@ -19,6 +19,7 @@ #include #include +#include #include "tools/converter/converter.h" #include "tools/converter/parser/tflite/tflite_model_parser.h" #include "tools/converter/graphdef_transform.h" diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index 61b2e3baf6..f994f90479 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -17,14 +17,19 @@ #include "tools/converter/parser/tflite/tflite_deconv_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,11 +40,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; std::unique_ptr attr(new schema::DeConv2DT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions(); if (tflite_attr == nullptr) { - MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } @@ -50,26 +54,48 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit attr->dilateW = 1; attr->padMode = GetPadMode(tflite_attr->padding); attr->format = schema::Format_NHWC; + 
attr->activationType = schema::ActivationType_NO_ACTIVATION; + attr->hasBias = true; // get the conv op weight tensor - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_tensors[weight_index]; if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; + MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { + auto weight_shape = weight_tensor->shape; + attr->channelIn = weight_shape[3]; + attr->channelOut = weight_shape[0]; + attr->kernelH = weight_shape[1]; + attr->kernelW = weight_shape[2]; + + // calculate pad params + auto data_index = tflite_op->inputs[2]; + const auto &data_tensor = tflite_tensors[data_index]; + std::vector params; + if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, + attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { + MS_LOG(ERROR) << "get padding params failed"; return RET_ERROR; + } else { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); } - auto weight_shape = weight_tensor->shape; - attr->channelIn = weight_shape[CHWK_K]; - attr->channelOut = weight_shape[CHWK_C]; - attr->kernelW = weight_shape[CHWK_W]; - attr->kernelH = weight_shape[CHWK_H]; op->primitive->value.type = schema::PrimitiveType_DeConv2D; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h index 46e7e1b8b6..0a26ceb68f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
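[Editor's sketch] Both the Conv2D hunk and the TransposeConv hunk above index the TFLite filter shape as [K, H, W, C] (output channels, kernel height, kernel width, input channels) when filling channelOut/kernelH/kernelW/channelIn. A small illustrative sketch; ConvAttr and AttrFromKhwcShape are made-up names, not the project's Conv2DT/DeConv2DT.

#include <cassert>
#include <vector>

struct ConvAttr {
  int channelIn, channelOut, kernelH, kernelW;
};

ConvAttr AttrFromKhwcShape(const std::vector<int> &weight_shape) {
  assert(weight_shape.size() == 4);
  ConvAttr attr{};
  attr.channelOut = weight_shape[0];  // K
  attr.kernelH = weight_shape[1];     // H
  attr.kernelW = weight_shape[2];     // W
  attr.channelIn = weight_shape[3];   // C
  return attr;
}

For DepthwiseConv2D the patch instead reads channelIn from the data tensor's last dimension, since the TFLite depthwise filter layout is [1, H, W, output_channels].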
*/ -#ifndef PREDICT_TFLITE_DECONV_PARSER_H -#define PREDICT_TFLITE_DECONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteDeConvParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_op_set, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_DECONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc index 04586855f5..b8fceaed9a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc @@ -18,14 +18,19 @@ #include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,20 +41,23 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; std::unique_ptr attr(new schema::DepthToSpaceT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsDepthToSpaceOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); return RET_NULL_PTR; } attr->blockSize = tflite_attr->block_size; - attr->format = schema::Format_NHWC; op->primitive->value.type = schema::PrimitiveType_DepthToSpace; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h index 3be9968d8d..6fac3d3cd1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h @@ -14,11 +14,12 @@ * limitations 
under the License. */ -#ifndef LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H -#define LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc index d3187aefc0..13354b61d1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -17,65 +17,22 @@ #include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h" #include #include +#include #include "tools/common/node_util.h" namespace mindspore { namespace lite { -STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op, - const std::unique_ptr &attr, - const std::unique_ptr &weightTensor, - TensorCache *tensor_cache) { - std::unique_ptr convAttr(new schema::Conv2DT); - convAttr->format = attr->format; - convAttr->channelIn = attr->channelIn; - convAttr->channelOut = attr->channelIn * attr->channelMultiplier; - convAttr->kernelH = attr->kernelH; - convAttr->kernelW = attr->kernelW; - convAttr->strideH = attr->strideH; - convAttr->strideW = attr->strideW; - convAttr->padMode = attr->padMode; - convAttr->padUp = attr->padUp; - convAttr->padDown = attr->padDown; - convAttr->padLeft = attr->padLeft; - convAttr->padRight = attr->padRight; - convAttr->dilateH = attr->dilateH; - convAttr->dilateW = attr->dilateW; - convAttr->hasBias = attr->hasBias; - convAttr->activationType = attr->activationType; - - auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name); - if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) { - auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex]; - if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) { - // convert weight format KHWC -> CHWK - auto status = TransFilterFormat(liteWeightTensor, kKHWC2CHWK); - if (status != RET_OK) { - MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; - return RET_ERROR; - } - } - - if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) { - // convert weight format KHWC -> CHWK - auto status = TransFilterFormat(liteWeightTensor, kKHWC2CHWK); - if (status != RET_OK) { - MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; - return RET_ERROR; - } - } - } - - op->primitive->value.type = schema::PrimitiveType_Conv2D; - op->primitive->value.value = convAttr.release(); - return RET_OK; -} STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr 
&tflite_op, const std::vector> &tflite_tensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -86,7 +43,6 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr attr(new schema::DepthwiseConv2DT()); const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); if (tflite_attr == nullptr) { @@ -100,15 +56,20 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptrpadMode = GetPadMode(tflite_attr->padding); attr->format = schema::Format_NHWC; attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); - // get the conv op weight tensor - auto input_index = tflite_op->inputs[0]; - const auto &input_tenosr = tflite_tensors[input_index]; - if (input_tenosr == nullptr) { - MS_LOG(ERROR) << "the first input is null"; + attr->hasBias = true; + attr->channelMultiplier = tflite_attr->depth_multiplier; + + // get the data tensor + auto data_index = tflite_op->inputs[1]; + const auto &data_tensor = tflite_tensors[data_index]; + if (data_tensor == nullptr) { + MS_LOG(ERROR) << "the data tensor is null"; return RET_NULL_PTR; } - auto input_shape = input_tenosr->shape; + auto data_shape = data_tensor->shape; + attr->channelIn = data_shape[3]; + // get the weight tensor auto weight_index = tflite_op->inputs[1]; const auto &weight_tensor = tflite_tensors[weight_index]; if (weight_tensor == nullptr) { @@ -116,38 +77,33 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptrshape; - attr->channelIn = input_shape[KHWC_C]; - attr->channelMultiplier = tflite_attr->depth_multiplier; - attr->kernelH = weight_shape[KHWC_H]; - attr->kernelW = weight_shape[KHWC_W]; - - std::vector weight_tensors{weight_tensor.get()}; - - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; + attr->kernelH = weight_shape[1]; + attr->kernelW = weight_shape[2]; + + // calculate pad params + std::vector params; + if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, + attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { + MS_LOG(ERROR) << "get padding params failed"; return RET_ERROR; - } - - if (tflite_op->inputs.size() == 3) { - attr->hasBias = true; - auto bias_index = tflite_op->inputs[2]; - const auto &bias_tensor = tflite_tensors[bias_index]; - std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse bias failed"; - return RET_ERROR; - } - } - - if (attr->channelMultiplier > 1) { - if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) { - MS_LOG(ERROR) << "Parse Group DepthwiseConv failed"; - return RET_ERROR; - } } else { - op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; - op->primitive->value.value = attr.release(); + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); } + + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, 
tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h index 2e0b1a0d02..6e4022f4fb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H -#define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,20 +29,16 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { public: TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override; - - private: - STATUS ParseGroupDepthwiseConv(schema::CNodeT *op, - const std::unique_ptr &attr, - const std::unique_ptr &weightTensor, - TensorCache *tensor_cache); + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 74bde414e5..ab0ac6d906 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -16,15 +16,20 @@ #include "tools/converter/parser/tflite/tflite_dequantize_parser.h" #include #include +#include #include "tools/common/node_util.h" namespace mindspore { namespace lite { -STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; 
return RET_NULL_PTR; @@ -35,32 +40,30 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &t return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; std::unique_ptr attr(new schema::CastT); // get the dequantize input tensor - const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; + MS_LOG(ERROR) << "input tensor is null"; return RET_NULL_PTR; } - attr->srcT = dtype_map[in_tensor->type]; - - const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]]; + attr->srcT = GetTfliteDataType(in_tensor->type); + const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { - MS_LOG(ERROR) << "tensor is null"; + MS_LOG(ERROR) << "output tensor is null"; return RET_NULL_PTR; } - attr->dstT = dtype_map[out_tensor->type]; - std::vector weight_tensors{in_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } + attr->dstT = GetTfliteDataType(out_tensor->type); op->primitive->value.type = schema::PrimitiveType_Fp16Cast; op->primitive->value.value = attr.release(); - return 0; + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; } TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h index 3d6e521d7d..276bd3e748 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. 
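[Editor's sketch] The Cast and Dequantize hunks replace the old dtype_map lookups with GetTfliteDataType. Its implementation is not part of this diff; the sketch below shows the kind of tflite-type-to-TypeId mapping it presumably provides, with simplified stand-in enums rather than tflite::TensorType and the real TypeId.

enum class TfliteType { FLOAT32, FLOAT16, INT32, UINT8, INT8, INT64, BOOL };
enum class TypeId { kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt32,
                    kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt64,
                    kNumberTypeBool, kTypeUnknown };

// Assumed mapping: one internal TypeId per supported TFLite tensor type.
TypeId GetDataType(TfliteType t) {
  switch (t) {
    case TfliteType::FLOAT32: return TypeId::kNumberTypeFloat32;
    case TfliteType::FLOAT16: return TypeId::kNumberTypeFloat16;
    case TfliteType::INT32:   return TypeId::kNumberTypeInt32;
    case TfliteType::UINT8:   return TypeId::kNumberTypeUInt8;
    case TfliteType::INT8:    return TypeId::kNumberTypeInt8;
    case TfliteType::INT64:   return TypeId::kNumberTypeInt64;
    case TfliteType::BOOL:    return TypeId::kNumberTypeBool;
  }
  return TypeId::kTypeUnknown;  // unrecognized tensor type
}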
*/ -#ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H -#define LITE_TFLITE_DEQUANTIZE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -27,13 +28,15 @@ class TfliteDequantizeParser : public TfliteNodeParser { public: TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_DEQUANTIZE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc index 624a0e5193..0d3c91528f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc @@ -17,16 +17,17 @@ #include "tools/converter/parser/tflite/tflite_expand_dims_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -40,7 +41,7 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &t MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; std::unique_ptr attr(new schema::ExpandDimsT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsExpandDimsOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h index aa867bc315..cdbda4b5b8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
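[Editor's sketch] The header hunks above all switch to the same Parse signature. Because the extraction stripped the template arguments throughout this diff, here is the declaration written out in full for a hypothetical TfliteFooParser; the angle-bracketed types are reconstructed and should be checked against tflite_node_parser.h rather than taken as authoritative. This is a declaration sketch against the project headers, not a standalone program.

#include <map>
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

namespace mindspore {
namespace lite {
// "Foo" is a placeholder parser name, used only for illustration.
class TfliteFooParser : public TfliteNodeParser {
 public:
  TfliteFooParser() : TfliteNodeParser("Foo") {}

  STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
               const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
               const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
               schema::CNodeT *op,
               std::vector<int32_t> *tensors_id,
               std::vector<schema::Format> *tensors_format,
               std::map<int, int> *tensors_id_map) override;
};
}  // namespace lite
}  // namespace mindspore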
*/ -#ifndef PREDICT_TFLITE_EXPAND_DIMS_PARSER_H -#define PREDICT_TFLITE_EXPAND_DIMS_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteExpandDimsParser : public TfliteNodeParser { public: TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_EXPAND_DIMS_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc deleted file mode 100644 index fa0e90ae11..0000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "tools/converter/parser/tflite/tflite_fakequant_parser.h" -#include -#include - -namespace mindspore { -namespace lite { -STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; - std::unique_ptr attr(new schema::FullConnectionT()); - - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; - if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; - return RET_NULL_PTR; - } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } - - if (tfliteOp->inputs.size() == 3) { - attr->hasBias = true; - auto bias_index = tfliteOp->inputs[2]; - const auto &bias_tensor = tfliteTensors[bias_index]; - if (bias_tensor == nullptr) { - MS_LOG(ERROR) << "bias_tensor is null"; - return RET_NULL_PTR; - } - std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse bias failed"; - return RET_ERROR; - } - } - attr->axis = 1; - - op->primitive->value.type = schema::PrimitiveType_FullConnection; - op->primitive->value.value = attr.release(); - return RET_OK; -} - -TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser()); -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h deleted file mode 100644 index 101c6cfec1..0000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
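[Editor's sketch] The deleted FakeQuant parser above, like every parser in this diff, was wired in through a file-scope TfliteNodeRegister object (its FullyConnected handling now lives in TfliteFullyConnectedParser). The registry itself is not shown here; the following is a generic C++ sketch of that static-registration pattern with stand-in NodeParser/ParserRegistry/NodeRegister types, not the project's tflite_node_parser_registry.h.

#include <map>
#include <memory>
#include <string>

struct NodeParser {
  virtual ~NodeParser() = default;
};

class ParserRegistry {
 public:
  static ParserRegistry *Instance() {
    static ParserRegistry inst;  // one process-wide registry
    return &inst;
  }
  void Add(const std::string &name, NodeParser *parser) {
    parsers_[name] = std::unique_ptr<NodeParser>(parser);
  }
  NodeParser *Get(const std::string &name) {
    auto it = parsers_.find(name);
    return it == parsers_.end() ? nullptr : it->second.get();
  }

 private:
  std::map<std::string, std::unique_ptr<NodeParser>> parsers_;
};

// A file-scope object per parser registers it at load time.
struct NodeRegister {
  NodeRegister(const std::string &name, NodeParser *parser) {
    ParserRegistry::Instance()->Add(name, parser);
  }
};

Usage mirrors the registrations kept in this diff, e.g. TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());.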
- */ -#ifndef LITE_TFLITE_FAKEQUANT_PARSER_H -#define LITE_TFLITE_FAKEQUANT_PARSER_H - -#include -#include -#include "tools/converter/parser/tflite/tflite_node_parser.h" -#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" - -namespace mindspore { -namespace lite { -class TfliteFakeQuantParser : public TfliteNodeParser { - public: - TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {} - - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; -}; -} // namespace lite -} // namespace mindspore - -#endif // LITE_TFLITE_FAKEQUANT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc index a7b3be7f66..70405dbf1e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc @@ -14,19 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_fill_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_fill_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteFillParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteFillParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteFillParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,18 +40,22 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteFillParser"; std::unique_ptr attr(new schema::FillT()); - if (tfliteOp->inputs.size() > 1) { - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dims)) { - MS_LOG(ERROR) << "get Fill -> dims failed"; + if (tflite_op->inputs.size() > 1) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) { + MS_LOG(ERROR) << "get fill -> dims failed"; return RET_ERROR; } } op->primitive->value.type = schema::PrimitiveType_Fill; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h index 5d8fdee06d..7bfcd5df99 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
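The Fill parser hunk above shows the shape every refactored node parser now follows: the old tfliteOpSet / TensorCache / quantizedModel parameters are gone, and each parser reports its tensors through AddOpInput / AddOpOutput against graph-level bookkeeping owned by the model parser. A condensed sketch of that contract, assembled from the Fill parser changes (the "Foo" parser is a placeholder, and the template arguments are restored by inference since this excerpt has them stripped):

// Hypothetical TfliteFooParser::Parse, illustrating the new parser contract only;
// a real parser is declared in its own header like the others in this patch.
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

STATUS TfliteFooParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
                              const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
                              const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
                              schema::CNodeT *op,
                              std::vector<int32_t> *tensors_id,
                              std::vector<schema::Format> *tensors_format,
                              std::map<int, int> *tensors_id_map) {
  MS_LOG(DEBUG) << "parse TfliteFooParser";
  if (op == nullptr) {
    MS_LOG(ERROR) << "op is null";
    return RET_NULL_PTR;
  }
  op->primitive = std::make_unique<schema::PrimitiveT>();
  if (op->primitive == nullptr) {
    MS_LOG(ERROR) << "op->primitive is null";
    return RET_NULL_PTR;
  }

  std::unique_ptr<schema::FillT> attr(new schema::FillT());  // op-specific attr type
  // ... fill attr fields from tflite_op->builtin_options or GetTfliteData ...
  op->primitive->value.type = schema::PrimitiveType_Fill;
  op->primitive->value.value = attr.release();

  // Every input/output index is routed through the shared tensors_id bookkeeping.
  AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
             tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
  AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
              tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
  return RET_OK;
}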
*/ -#ifndef PREDICT_TFLITE_FILL_PARSER_H -#define PREDICT_TFLITE_FILL_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FILL_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FILL_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteFillParser : public TfliteNodeParser { public: TfliteFillParser() : TfliteNodeParser("Fill") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_FILL_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FILL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc index 0084ad8c7e..ba02dc35ef 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -14,17 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_fullyconnected_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_fullyconnected_parser.h" +#include +#include namespace mindspore { namespace lite { -STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,51 +39,43 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr node_name_str; + Split(op->name, &node_name_str, "-"); + const char *node_name = node_name_str.data()->c_str(); + if (std::strcmp(node_name, "FullyConnected") == 0) { + MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; + } else if (std::strcmp(node_name, "FakeQuant") == 0) { + MS_LOG(DEBUG) << "parse TfliteFakeQuantParser"; + } std::unique_ptr attr(new schema::FullConnectionT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsFullyConnectedOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsFullyConnectedOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); - - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; - if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; return RET_NULL_PTR; } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, 
tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } - if (tfliteOp->inputs.size() == 3) { - attr->hasBias = true; - auto bias_index = tfliteOp->inputs[2]; - const auto &bias_tensor = tfliteTensors[bias_index]; - if (bias_tensor == nullptr) { - MS_LOG(ERROR) << "bias_tensor is null"; - return RET_NULL_PTR; - } - std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse bias failed"; - return RET_ERROR; - } - } + attr->hasBias = true; attr->axis = 1; attr->useAxis = false; + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); op->primitive->value.type = schema::PrimitiveType_FullConnection; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser()); +TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser());; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h index f41ab2e3c0..21fd8186ad 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
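With the standalone FakeQuant files removed, tflite_fullyconnected_parser.cc now covers both op types: it registers "FullyConnected" and "FakeQuant", always sets hasBias and wires three inputs (data in NHWC, weight in KHWC, bias in NHWC), and only the debug message depends on which op is being parsed, chosen by splitting op->name on "-". A self-contained sketch of that dispatch; Split() here is a stand-in for the converter's helper of the same name:

// Name-prefix dispatch: op->name is "<OpType>-<index>", set by the model parser,
// and the prefix picks the branch. Compiles standalone with standard headers only.
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

// Minimal stand-in for the converter's Split() utility.
void Split(const std::string &src, std::vector<std::string> *out, const std::string &delim) {
  size_t start = 0, pos;
  while ((pos = src.find(delim, start)) != std::string::npos) {
    out->push_back(src.substr(start, pos - start));
    start = pos + delim.size();
  }
  out->push_back(src.substr(start));
}

int main() {
  std::string op_name = "FakeQuant-7";            // assembled as op_type + "-" + index
  std::vector<std::string> parts;
  Split(op_name, &parts, "-");
  const char *node_name = parts.data()->c_str();  // first token, as in the parser
  if (std::strcmp(node_name, "FullyConnected") == 0) {
    std::cout << "parse TfliteFullyConnectedParser\n";
  } else if (std::strcmp(node_name, "FakeQuant") == 0) {
    std::cout << "parse TfliteFakeQuantParser\n";
  }
  return 0;
}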
*/ -#ifndef PREDICT_TFLITE_ADD_PARSER_H -#define PREDICT_TFLITE_ADD_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FULLY_CONNECTED_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FULLY_CONNECTED_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,21 @@ class TfliteFullyConnectedParser : public TfliteNodeParser { public: TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; +}; + +class TfliteFakeQuantParser : public TfliteFullyConnectedParser { + public: + TfliteFakeQuantParser() : TfliteFullyConnectedParser() {} }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ADD_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FULLY_CONNECTED_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc index a954baef01..e8218bb431 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_gather_nd_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,37 +40,18 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tfl return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; std::unique_ptr attr(new schema::GatherNdT()); - - if (tfliteOp->inputs.size() != 2) { - MS_LOG(ERROR) << "The input size of gather_nd should be 2"; - return RET_ERROR; - } - - auto y_index = tfliteOp->inputs[1]; - const auto &y_tensor = tfliteTensors[y_index]; - if (y_tensor == nullptr) { - MS_LOG(ERROR) << "the second input is null"; - return RET_NULL_PTR; - } - auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); - if (y_data == nullptr) { - MS_LOG(ERROR) << "the data of the second input is null"; - return RET_NULL_PTR; - } - if (!y_data->data.empty()) { - std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse the second tensor failed"; - return RET_ERROR; - } - } - attr->batchDims = 0; op->primitive->value.type = schema::PrimitiveType_GatherNd; op->primitive->value.value = 
attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h index c79d8aa753..4d9c3e525c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_GATHER_ND_PARSER_H -#define PREDICT_TFLITE_GATHER_ND_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_ND_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_ND_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteGatherNdParser : public TfliteNodeParser { public: TfliteGatherNdParser() : TfliteNodeParser("GatherND") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_GATHER_ND_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_ND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc index e49f95d511..cf388e94d8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -17,16 +17,20 @@ #include "tools/converter/parser/tflite/tflite_gather_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteGatherParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteGatherParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteGatherParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,39 +41,25 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteGatherParser"; std::unique_ptr attr(new schema::GatherT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsGatherOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsGatherOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() 
<< " attr failed"; return RET_NULL_PTR; } attr->axis = tflite_attr->axis; - attr->batchDims = 0; - auto y_index = tfliteOp->inputs[1]; - const auto &y_tensor = tfliteTensors[y_index]; - if (y_tensor == nullptr) { - MS_LOG(ERROR) << "the second input is null"; - return RET_NULL_PTR; - } - auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); - if (y_data == nullptr) { - MS_LOG(ERROR) << "the data of the second input is null"; - return RET_NULL_PTR; - } - if (!y_data->data.empty()) { - std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse the second tensor failed"; - return RET_ERROR; - } - } - op->primitive->value.type = schema::PrimitiveType_Gather; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h index 5dd842414a..08ea38976c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_GATHER_PARSER_H -#define PREDICT_TFLITE_GATHER_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteGatherParser : public TfliteNodeParser { public: TfliteGatherParser() : TfliteNodeParser("Gather") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_GATHER_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc deleted file mode 100644 index e7ff131f73..0000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "tools/converter/parser/tflite/tflite_hard_swish_parser.h" - -namespace mindspore { -namespace lite { -STATUS TfliteHardSwishParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - MS_LOG(INFO) << "parse TfliteHardSwishParser"; - std::unique_ptr attr(new schema::ActivationT()); - - attr->type = schema::ActivationType_HSWISH; - - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); - return RET_OK; -} - -TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h deleted file mode 100644 index 00de1d2458..0000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef PREDICT_TFLITE_HARD_SWISH_PARSER_H -#define PREDICT_TFLITE_HARD_SWISH_PARSER_H - -#include -#include -#include "tools/converter/parser/tflite/tflite_node_parser.h" -#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" - -namespace mindspore { -namespace lite { -class TfliteHardSwishParser : public TfliteNodeParser { - public: - TfliteHardSwishParser() : TfliteNodeParser("HardSwish") {} - - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; -}; -} // namespace lite -} // namespace mindspore - -#endif // PREDICT_TFLITE_HARD_SWISH_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc index 078523ff40..5f4a4d1835 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc @@ -14,18 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_logical_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_logical_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +40,7 @@ STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tfli } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "LogicalAnd") == 0) { MS_LOG(DEBUG) << "parse TfliteLogicalAndParser"; @@ -45,21 +48,24 @@ STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_LogicalAnd; op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "LogicalNot") == 0) { - MS_LOG(INFO) << "parse TfliteLogicalNotParser"; + MS_LOG(DEBUG) << "parse TfliteLogicalNotParser"; std::unique_ptr attr(new schema::LogicalNotT()); op->primitive->value.type = schema::PrimitiveType_LogicalNot; op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "LogicalOr") == 0) { - MS_LOG(INFO) << "parse TfliteLogicalOrParser"; + MS_LOG(DEBUG) << "parse TfliteLogicalOrParser"; std::unique_ptr attr(new schema::LogicalOrT()); op->primitive->value.type = schema::PrimitiveType_LogicalOr; op->primitive->value.value = attr.release(); - } else { - MS_LOG(ERROR) << "wrong logical type"; - return RET_ERROR; } -return RET_OK; + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), 
schema::Format_NHWC); + return RET_OK; } TfliteNodeRegister g_TfliteLogicalAndParser("LogicalAnd", new TfliteLogicalAndParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h index 3608f1f12d..a56de847e6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_LOGICAL_AND_PARSER_H -#define PREDICT_TFLITE_LOGICAL_AND_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_AND_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_AND_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -29,12 +30,13 @@ class TfliteLogicalParser : public TfliteNodeParser { public: TfliteLogicalParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteLogicalAndParser : public TfliteLogicalParser { @@ -54,4 +56,4 @@ class TfliteLogicalOrParser : public TfliteLogicalParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_LOGICAL_AND_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOGICAL_AND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc index 7d37779377..8986729d72 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc @@ -17,16 +17,20 @@ #include "tools/converter/parser/tflite/tflite_lrn_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteLRNParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, +STATUS TfliteLRNParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteLRNParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,10 +41,9 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr &tfliteOp return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteLRNParser"; std::unique_ptr attr(new schema::LocalResponseNormalizationT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsLocalResponseNormalizationOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsLocalResponseNormalizationOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -52,6 +55,11 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr &tfliteOp 
op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h index b7eae4f978..a64179c64b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_LRN_PARSER_H -#define PREDICT_TFLITE_ADD_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LRN_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LRN_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteLRNParser : public TfliteNodeParser { public: TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_LRN_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LRN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 7ea6fecfd3..b9990d0ace 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -17,6 +17,7 @@ #include "tools/converter/parser/tflite/tflite_model_parser.h" #include #include +#include #include "tools/common/graph_util.h" #include "tools/common/storage.h" #include "flatbuffers/flatbuffers.h" @@ -24,43 +25,45 @@ namespace mindspore { namespace lite { -TfliteModelParser::TfliteModelParser() {} +TfliteModelParser::TfliteModelParser() = default; -TfliteModelParser::~TfliteModelParser() {} +TfliteModelParser::~TfliteModelParser() = default; -std::unique_ptr TfliteModelParser::ReadTfliteModelFromFlat(const char *model_path) { +std::unique_ptr TfliteModelParser::ReadTfliteModel(const char *model_path) { size_t size; auto buf = ReadFile(model_path, &size); if (buf == nullptr) { MS_LOG(ERROR) << "the file buffer is nullptr"; - return nullptr; } flatbuffers::Verifier verify((const uint8_t *)buf, size); if (!tflite::VerifyModelBuffer(verify)) { MS_LOG(ERROR) << "the buffer is invalid and fail to create graph"; - return nullptr; } return tflite::UnPackModel(buf); } -std::string TfliteModelParser::GetTfliteNodeType(const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model) { - auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; - auto msOpType = 
GetMSOpType(tflite_op_type); - return msOpType; -} - -STATUS TfliteModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graphDef) { - std::vector tensors = tensor_cache.GetCachedTensor(); - for (auto iter : tensors) { - std::unique_ptr temp(iter); - temp->format = schema::Format_NHWC; - sub_graphDef->allTensors.emplace_back(move(temp)); +STATUS TfliteModelParser::CopyConstTensorData(const std::vector> &tflite_model_buffer, + const tflite::TensorT *tflite_tensor, + schema::TensorT *tensor) { + auto count = 1; + std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; }); + auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType)); + auto buffer_idx = tflite_tensor->buffer; + if (!tflite_model_buffer[buffer_idx]->data.empty()) { + tensor->data.resize(data_size); + if (memcpy_s(tensor->data.data(), data_size, tflite_model_buffer[buffer_idx]->data.data(), data_size)) { + MS_LOG(ERROR) << "memcpy tensor data failed"; + return RET_ERROR; + } + } else { + MS_LOG(ERROR) << "src tensor data is empty"; + return RET_ERROR; } return RET_OK; } -void TfliteModelParser::SetMsTensorFromTflite(const std::unique_ptr &tflite_tensor, - schema::TensorT *tensor) { + +void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr &tflite_tensor, + schema::TensorT *tensor) { std::unique_ptr quant_param(new QuantParamT()); if (!tflite_tensor->quantization->scale.empty()) { quant_param->scale = tflite_tensor->quantization->scale[0]; @@ -87,221 +90,228 @@ void TfliteModelParser::SetMsTensorFromTflite(const std::unique_ptrquantParams.emplace_back(std::move(quant_param)); } -STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, - schema::CNodeT *op, TensorCache *tensor_cache) { - MS_ASSERT(op->outputIndex.size() == tflite_op->outputs.size()); - for (size_t i = 0; i < tflite_op->inputs.size() && i < op->inputIndex.size(); i++) { - const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(i)]; - if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && - tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { - continue; - } - auto &inTensor = tensor_cache->GetCachedTensor().at(op->inputIndex.at(i)); - if (inTensor == nullptr) { - MS_LOG(ERROR) << "Parse tflite quant params inTensor is null"; +STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + const QuantType &quant_type, + schema::MetaGraphT* sub_graph) { + int idx = 0; + for (const auto &tflite_op : tflite_subgraph->operators) { + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + auto op_type = GetMSOpType(tflite_op_type); + + std::unique_ptr op(new schema::CNodeT); + op->name = op_type + "-" + std::to_string(idx++); + op->quantType = quant_type; + MS_LOG(INFO) << "parse op: " << op->name.c_str(); + + // parse tflite op + auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); + if (node_parser == nullptr) { + MS_LOG(ERROR) << "cannot find node parser, opType: " << op_type.c_str(); return RET_NULL_PTR; } - SetMsTensorFromTflite(tflite_tensor, inTensor); - } - for (size_t i = 0; i < tflite_op->outputs.size() && i < op->outputIndex.size(); i++) { - const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->outputs.at(i)]; - if (tflite_tensor->quantization->scale.empty() && 
tflite_tensor->quantization->zero_point.empty() && - tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { - continue; - } - auto &outTensor = tensor_cache->GetCachedTensor().at(op->outputIndex.at(i)); - if (outTensor == nullptr) { - MS_LOG(ERROR) << "Parse tflite quant params outTensor is null"; - return RET_NULL_PTR; + if (node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, + &tensorsFormat, &tensorsIdMap) != RET_OK) { + MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; + return RET_ERROR; } - SetMsTensorFromTflite(tflite_tensor, outTensor); + + // add + sub_graph->nodes.emplace_back(op.release()); + opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); + tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); } return RET_OK; } -STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensorCache) { - for (const auto &index : tflite_op->outputs) { - const auto &tflite_tensor = tflite_subgraph->tensors[index]; - if (tflite_tensor == nullptr) { - MS_LOG(ERROR) << "tensor with id = " << index << " is null"; - return RET_ERROR; +STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr &tflite_subgraph, + const std::vector> &tflite_model_buffer, + schema::MetaGraphT* sub_graph) { + for (int i = 0; i < tensorsId.size(); i++) { + auto idx = tensorsId[i]; + if (idx < 0) { + idx += tflite_subgraph->tensors.size(); } + const auto &tflite_tensor = tflite_subgraph->tensors[idx]; std::unique_ptr tensor(new schema::TensorT()); + + tensor->format = tensorsFormat[i]; tensor->dataType = GetTfliteDataType(tflite_tensor->type); - // change dataType to int8 to fit ms-lite op - if (tensor->dataType == TypeId::kNumberTypeUInt8) { - tensor->dataType = TypeId::kNumberTypeInt8; - } tensor->dims = tflite_tensor->shape; - tensor->nodeType = schema::NodeType_Parameter; - auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT); - op->outputIndex.emplace_back(opOutputIndex); - } - return RET_OK; -} -STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensor_cache) { - auto op_type = GetTfliteNodeType(tflite_op, tflite_model); - std::vector op_inputs(tflite_op->inputs); - if (op_type == "DeConv2D") { - reverse(op_inputs.begin(), op_inputs.end()); - } + // if graph input tensor + bool isInput = false; + auto tflite_inputs = tflite_subgraph->inputs; + for (int tflite_input : tflite_inputs) { + if (idx == tflite_input) { + isInput = true; + break; + } + } - for (const auto &tflite_index : op_inputs) { - const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index]; - if (tflite_tensor == nullptr) { - MS_LOG(ERROR) << "tensor with id = " << tflite_index << " is null"; - return RET_ERROR; + // add data for const tensor + auto &tensor_buffer = tflite_model_buffer.at(tflite_tensor->buffer); + auto isConst = (!tensor_buffer->data.empty()); + if (isConst) { + CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); } - auto tensor_name = tflite_tensor->name; - unsigned int index = tensor_cache->FindTensor(tensor_name); - if (index != -1) { - op->inputIndex.push_back(index); + + // set tensor attr + if (isInput || isConst) { + tensor->nodeType = schema::NodeType_ValueNode; + } else { + tensor->nodeType = 
schema::NodeType_Parameter; } + + // quant param + if (!(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && + tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) { + SetTensorQuantParam(tflite_tensor, tensor.get()); + } + + tensors.push_back(tensor.release()); } + for (auto iter : tensors) { + std::unique_ptr temp(iter); + sub_graph->allTensors.emplace_back(move(temp)); + } return RET_OK; } -STATUS TfliteModelParser::ParseOp(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, - schema::MetaGraphT *subGraph, mindspore::lite::TensorCache *tensorCache, - const QuantType &quantType) { - auto i = 0; - for (const auto &tflite_op : tflite_subgraph->operators) { - auto opType = GetTfliteNodeType(tflite_op, tflite_model); - - std::unique_ptr op(new schema::CNodeT); - op->name = opType + "-" + std::to_string(i++); - op->quantType = quantType; - MS_LOG(INFO) << "parse op: " << op->name.c_str(); +STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr &tflite_subgraph, + schema::MetaGraphT* sub_graph) { + int id; - auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); - if (node_parser == nullptr) { - MS_LOG(ERROR) << "cannot find node parser, opType: " << opType.c_str(); - continue; - // return RET_NULL_PTR; + // graph input + std::vector graph_inputs; + for (int i = 0; i < tflite_subgraph->inputs.size(); i++) { + const int idx = tflite_subgraph->inputs[i]; + if (idx < 0) { + id = idx + tflite_subgraph->tensors.size(); + } else { + id = idx; } - - auto status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, - tflite_model->operator_codes, op.get(), tensorCache, false); - if (status != RET_OK) { - MS_LOG(ERROR) << "node " << opType.c_str() << " parser failed"; - return RET_ERROR; - } - - status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache); - if (status != RET_OK) { - MS_LOG(ERROR) << "set op " << opType.c_str() << " output index failed"; - return RET_ERROR; + auto iter = tensorsIdMap.find(id); + if (iter != tensorsIdMap.end()) { + graph_inputs.push_back(iter->second); } + } + sub_graph->inputIndex.assign(graph_inputs.begin(), graph_inputs.end()); - status = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, op.get(), tensorCache); - if (status != RET_OK) { - MS_LOG(ERROR) << "set op " << opType.c_str() << " input index failed"; - return RET_ERROR; + // graph output + std::vector graph_outputs; + for (int i = 0; i < tflite_subgraph->outputs.size(); i++) { + const int idx = tflite_subgraph->outputs[i]; + if (idx < 0) { + id = idx + tflite_subgraph->tensors.size(); + } else { + id = idx; } - - if (quantType != schema::QuantType_QUANT_NONE) { - status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache); - if (status != RET_OK) { - MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed"; - return RET_ERROR; - } + auto iter = tensorsIdMap.find(id); + if (iter != tensorsIdMap.end()) { + graph_outputs.push_back(iter->second); } - - subGraph->nodes.emplace_back(std::move(op)); - opMap[subGraph->nodes.back()->name] = subGraph->nodes.back().get(); - tfliteOpMap[tflite_op.get()] = subGraph->nodes.back().get(); } + sub_graph->outputIndex.assign(graph_outputs.begin(), graph_outputs.end()); return RET_OK; } -void TfliteModelParser::SetInputTensor(const std::unique_ptr &tflite_subgraph, - TensorCache *tensor_cache) { - for (const auto &index : tflite_subgraph->inputs) { - 
const auto &tflite_tensor = tflite_subgraph->tensors[index]; - std::unique_ptr tensor(new schema::TensorT()); - tensor->format = schema::Format_NHWC; - tensor->dataType = GetTfliteDataType(tflite_tensor->type) != TypeId::kNumberTypeUInt8 - ? GetTfliteDataType(tflite_tensor->type) - : TypeId::kNumberTypeInt8; - tensor->nodeType = schema::NodeType_Parameter; - tensor->dims = tflite_tensor->shape; - tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT); - } -} +STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT* sub_graph) { + for (auto &op : sub_graph->nodes) { + if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + auto attr = op->primitive->value.AsDepthwiseConv2D(); + if (attr->channelMultiplier > 1) { + // update attr + std::unique_ptr conv_attr(new schema::Conv2DT); + conv_attr->group = 0; + conv_attr->format = attr->format; + conv_attr->channelIn = attr->channelIn; + conv_attr->channelOut = attr->channelIn * attr->channelMultiplier; + conv_attr->kernelH = attr->kernelH; + conv_attr->kernelW = attr->kernelW; + conv_attr->strideH = attr->strideH; + conv_attr->strideW = attr->strideW; + conv_attr->padMode = attr->padMode; + conv_attr->padUp = attr->padUp; + conv_attr->padDown = attr->padDown; + conv_attr->padLeft = attr->padLeft; + conv_attr->padRight = attr->padRight; + conv_attr->dilateH = attr->dilateH; + conv_attr->dilateW = attr->dilateW; + conv_attr->hasBias = attr->hasBias; + conv_attr->activationType = attr->activationType; + + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = conv_attr.release(); -void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_model, - const mindspore::lite::TensorCache &tensorCache, - schema::MetaGraphT *subGraphDef) { - auto graphInputs = tensorCache.GetGraphInputs(); - subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end()); - - for (auto outputIndex : tflite_subgraph->outputs) { - int i = 0; - bool found = false; - for (const auto &tfliteOp : tflite_subgraph->operators) { - int j = 0; - auto opType = GetTfliteNodeType(tfliteOp, tflite_model); - std::string opName = opType + "-" + std::to_string(i++); - for (auto opOutputIndex : tfliteOp->outputs) { - if (outputIndex == opOutputIndex) { - subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]); - found = true; - break; + // update weight + auto weight_id = op->inputIndex[1]; + auto &weight_tensor = sub_graph->allTensors.at(weight_id); + if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { + // convert weight format KHWC -> CHWK + auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + return RET_ERROR; + } + } + if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { + // convert weight format KHWC -> CHWK + auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + return RET_ERROR; + } } - j++; - } - if (found) { - break; } } } + return RET_OK; } -MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType) { - if (ValidateFileStr(modelFile, ".tflite") != RET_OK) { - MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite"; + +MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, + const std::string 
&weight_file, + const QuantType &quant_type) { + std::unique_ptr sub_graph(new schema::MetaGraphT); + sub_graph->name = "MS_model converted by TF-Lite"; + + // load graph + // std::unique_ptr tflite_model(new tflite::ModelT()); + std::unique_ptr tflite_model = ReadTfliteModel(model_file.c_str()); + + if (tflite_model->subgraphs.size() != 1) { + MS_LOG(ERROR) << "read tflite model subgraphs failed"; return nullptr; } + const auto &tflite_subgraph = tflite_model->subgraphs[0]; - std::unique_ptr tflite_model(new tflite::ModelT()); - tflite_model = ReadTfliteModelFromFlat(modelFile.c_str()); - if (tflite_model == nullptr) { - MS_LOG(ERROR) << "read tflite model failed"; + // convert op + if (ConvertOp(tflite_model, tflite_subgraph, quant_type, sub_graph.get()) != RET_OK) { + MS_LOG(ERROR) << "parse op failed."; return nullptr; } - if (tflite_model->subgraphs.size() != 1) { - MS_LOG(ERROR) << "read tflite model subgraphs failed"; + + // convert tensor + if (ConvertTensor(tflite_subgraph, tflite_model->buffers, sub_graph.get()) != RET_OK) { + MS_LOG(ERROR) << "convert tensor failed"; return nullptr; } - const auto &tflite_subgraph = tflite_model->subgraphs[0]; - // set dst subGraph input/output tensor - TensorCache tensorCache; - SetInputTensor(tflite_subgraph, &tensorCache); + // set graph input/output + if (GetGraphInfo(tflite_subgraph, sub_graph.get()) != RET_OK) { + MS_LOG(ERROR) << "convert tensors failed"; + return nullptr; + } - // set dst subGraph op attr and tensor_cache. - std::unique_ptr subGraph(new schema::MetaGraphT); - subGraph->name = "MS_model converted by TF-Lite"; - auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache, quantType); - if (status != RET_OK) { - MS_LOG(ERROR) << "ParseOp failed."; + // update for depthwiseConv + if (UpdateOp(sub_graph.get()) != RET_OK) { + MS_LOG(ERROR) << "update depthwise conv failed"; return nullptr; } - SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get()); - SetAllTensors(tensorCache, subGraph.get()); - return subGraph.release(); + return sub_graph.release(); } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index ed861bb866..4c8eba0728 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "securec/include/securec.h" #include "tools/converter/model_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -38,44 +39,41 @@ class TfliteModelParser : public ModelParser { public: TfliteModelParser(); - virtual ~TfliteModelParser(); + ~TfliteModelParser() override; - MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, + MetaGraphT *Parse(const std::string &model_file, + const std::string &weight_file, const QuantType &quantType = QuantType_QUANT_NONE) override; private: - std::unique_ptr ReadTfliteModelFromFlat(const char *buf); + std::unique_ptr ReadTfliteModel(const char *model_path); - void SetMsTensorFromTflite(const std::unique_ptr &tflite_tensor, schema::TensorT *tensor); + STATUS CopyConstTensorData(const std::vector> &tflite_model_buffer, + const tflite::TensorT *tflite_tensor, + schema::TensorT *tensor); - void SetInputTensor(const std::unique_ptr &tflite_subgraph, TensorCache *tensor_cache); + void SetTensorQuantParam(const 
std::unique_ptr &tflite_tensor, + schema::TensorT *tensor); - void SetGraphTensorIndex(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_model, - const mindspore::lite::TensorCache &tensorCache, - schema::MetaGraphT *subGraphDef); + STATUS ConvertOp(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + const QuantType &quant_type, + schema::MetaGraphT* sub_graph); - STATUS ParseOp(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, schema::MetaGraphT *sub_graph, - TensorCache *tensor_cache, const QuantType &quantType); + STATUS ConvertTensor(const std::unique_ptr &tflite_subgraph, + const std::vector> &tflite_model_buffer, + schema::MetaGraphT* sub_graph); - STATUS ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensor_cache); + STATUS GetGraphInfo(const std::unique_ptr &tflite_subgraph, + schema::MetaGraphT* sub_graph); - std::string GetTfliteNodeType(const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model); + STATUS UpdateOp(schema::MetaGraphT* sub_graph); - STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph); - - STATUS SetOpOutputIdx(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensorCache); - - STATUS SetOpInputIdx(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensor_cache); + private: + std::vector tensorsId; + std::vector tensorsFormat; + std::map tensorsIdMap; + std::vector tensors; std::map opMap; std::map tfliteOpMap; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc deleted file mode 100644 index cd24097896..0000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
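UpdateOp, shown a little earlier in tflite_model_parser.cc, is the post-processing step of the new Parse pipeline (ReadTfliteModel → ConvertOp → ConvertTensor → GetGraphInfo → UpdateOp): a DepthwiseConv2D whose channelMultiplier is greater than 1 is rewritten as a regular Conv2D with the output channels expanded and the weight tensor transposed from KHWC to CHWK. A self-contained model of the attribute mapping, with plain structs standing in for the generated schema types:

// DepthwiseConv2D -> Conv2D attribute rewrite applied by UpdateOp when channelMultiplier > 1.
// Plain structs replace schema::DepthwiseConv2DT / schema::Conv2DT so this compiles on its own.
#include <iostream>

struct DepthwiseConv2D { int channelIn, channelMultiplier, kernelH, kernelW, strideH, strideW; };
struct Conv2D { int group, channelIn, channelOut, kernelH, kernelW, strideH, strideW; };

Conv2D ToConv2D(const DepthwiseConv2D &dw) {
  Conv2D conv{};
  conv.group = 0;                                         // ordinary convolution, no grouping
  conv.channelIn = dw.channelIn;
  conv.channelOut = dw.channelIn * dw.channelMultiplier;  // multiplier expands the output channels
  conv.kernelH = dw.kernelH;
  conv.kernelW = dw.kernelW;
  conv.strideH = dw.strideH;
  conv.strideW = dw.strideW;
  // pad/dilation/bias/activation fields are copied the same way in the real code;
  // the weight tensor is additionally transposed KHWC -> CHWK via TransFilterFormat.
  return conv;
}

int main() {
  DepthwiseConv2D dw{/*channelIn=*/8, /*channelMultiplier=*/2, 3, 3, 1, 1};
  Conv2D conv = ToConv2D(dw);
  std::cout << "channelOut = " << conv.channelOut << "\n";  // prints 16
  return 0;
}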
- */ - -#include -#include -#include "securec/include/securec.h" -#include "tools/converter/parser/tflite/tflite_node_parser.h" -#include "tools/converter/parser/tflite/tflite_util.h" - -namespace mindspore { -namespace lite { -STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, - const tflite::TensorT *tflite_tensor, - schema::TensorT *tensor) { - auto count = 1; - std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; }); - auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType)); - auto buffer_idx = tflite_tensor->buffer; - if (!tfliteModelBuffer[buffer_idx]->data.empty()) { - tensor->data.resize(data_size); - if (memcpy_s(tensor->data.data(), data_size, tfliteModelBuffer[buffer_idx]->data.data(), data_size)) { - MS_LOG(ERROR) << "memcpy tensor data failed"; - return RET_ERROR; - } - } else { - MS_LOG(ERROR) << "src tensor data is empty"; - return RET_ERROR; - } - return RET_OK; -} - -STATUS TfliteNodeParser::ParseTensor(const std::vector &ts, - const std::vector> &tfliteModelBuffer, - mindspore::lite::TensorCache *tensor_cache, - int node_type, - bool isWeight) { - for (const auto &t : ts) { - auto idx = tensor_cache->FindTensor(t->name); - if (idx < 0) { - std::unique_ptr tensor(new schema::TensorT); - tensor->dataType = GetTfliteDataType(t->type); - tensor->dims = t->shape; - - if (isWeight) { - tensor->format = schema::Format_KHWC; - } else { - tensor->format = schema::Format_NHWC; - } - - if (t->buffer > 0) { - CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get()); - } - - MS_LOG(DEBUG) << "add tensor name: " << t->name.c_str(); - tensor_cache->AddTensor(t->name, tensor.release(), node_type); - } - } - return RET_OK; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h index de5f0d1b28..c2422e8d90 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -21,54 +21,84 @@ #include #include #include +#include #include "utils/log_adapter.h" #include "schema/inner/model_generated.h" -#include "tools/converter/parser/tflite/tflite_util.h" #include "tools/converter/parser/tflite/schema_generated.h" #include "tools/common/tensor_util.h" #include "ir/dtype/type_id.h" #include "include/errorcode.h" +#include "tools/converter/parser/tflite/tflite_util.h" namespace mindspore { namespace lite { class TfliteNodeParser { public: - explicit TfliteNodeParser(const std::string &nodeName) : name(nodeName) {} + explicit TfliteNodeParser(const std::string &node_name) : name(node_name) {} virtual ~TfliteNodeParser() = default; - virtual STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, + virtual STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) = 0; + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) = 0; - STATUS ParseTensor(const std::vector &ts, - const std::vector> &tfliteModelBuffer, - mindspore::lite::TensorCache *tensor_cache, - int node_type, - bool isWeight); + void AddOpInput(schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map 
*tensors_id_map, + int idx, int new_idx, int total, schema::Format format) { + auto iter = tensors_id_map->find(idx); + if (iter != tensors_id_map->end()) { + op->inputIndex.emplace_back(iter->second); + } else { + if (idx < 0) { + idx += total; + } + tensors_id->emplace_back(idx); + tensors_format->emplace_back(format); + tensors_id_map->insert(std::make_pair(idx, new_idx)); + op->inputIndex.emplace_back(new_idx); + } + } - STATUS CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, - const tflite::TensorT *tflite_tensor, - schema::TensorT *tensor); + void AddOpOutput(schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map, + int idx, int new_idx, int total, schema::Format format) { + auto iter = tensors_id_map->find(idx); + if (iter != tensors_id_map->end()) { + op->outputIndex.emplace_back(iter->second); + } else { + if (idx < 0) { + idx += total; + } + tensors_id->emplace_back(idx); + tensors_format->emplace_back(format); + tensors_id_map->insert(std::make_pair(idx, new_idx)); + op->outputIndex.emplace_back(new_idx); + } + } template - STATUS GetTfliteData(const int32_t tensor_index, const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, + STATUS GetTfliteData(const int32_t tensor_index, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, std::vector &attr_data) { int32_t count = 1; - std::for_each(tfliteTensors[tensor_index]->shape.begin(), tfliteTensors[tensor_index]->shape.end(), + std::for_each(tflite_tensors[tensor_index]->shape.begin(), tflite_tensors[tensor_index]->shape.end(), [&](int32_t sha) { count *= sha; }); - auto &buf_data = tfliteModelBuffer[tfliteTensors[tensor_index]->buffer]; + auto &buf_data = tflite_model_buffer[tflite_tensors[tensor_index]->buffer]; if (buf_data == nullptr) { MS_LOG(ERROR) << "buf_data is null"; return RET_NULL_PTR; } auto data_ptr = buf_data->data.data(); - switch (tfliteTensors[tensor_index]->type) { + switch (tflite_tensors[tensor_index]->type) { case tflite::TensorType_UINT8: { for (int i = 0; i < count; i++) { uint8_t data = *(static_cast(static_cast(data_ptr))); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc index 0da5759332..4f8b726161 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. 
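AddOpInput and AddOpOutput above are where the per-graph tensor bookkeeping happens: tensors_id_map de-duplicates tflite tensor indices (normalizing negative ones against the tensor count), tensors_id and tensors_format record the order and layout in which ConvertTensor will later emit the converted tensors, and the node's inputIndex/outputIndex receive the mapped position. A self-contained model of that mapping using standard types only:

// Model of the tensors_id / tensors_id_map bookkeeping shared by AddOpInput/AddOpOutput.
// Project types (schema::CNodeT, schema::Format) are replaced with ints so it compiles alone.
#include <iostream>
#include <map>
#include <vector>

// Returns the MetaGraph-side index for a tflite tensor index, registering it on first use.
int MapTensor(int tflite_idx, int total_tensors,
              std::vector<int> *tensors_id, std::map<int, int> *tensors_id_map) {
  auto iter = tensors_id_map->find(tflite_idx);
  if (iter != tensors_id_map->end()) {
    return iter->second;             // already registered: reuse the same converted-tensor slot
  }
  if (tflite_idx < 0) {
    tflite_idx += total_tensors;     // tflite uses negative indices for trailing/optional tensors
  }
  int new_idx = static_cast<int>(tensors_id->size());
  tensors_id->push_back(tflite_idx);
  (*tensors_id_map)[tflite_idx] = new_idx;
  return new_idx;
}

int main() {
  std::vector<int> tensors_id;
  std::map<int, int> tensors_id_map;
  // Two ops sharing tflite tensor 3 end up pointing at the same converted tensor.
  std::cout << MapTensor(3, 10, &tensors_id, &tensors_id_map) << "\n";  // 0
  std::cout << MapTensor(5, 10, &tensors_id, &tensors_id_map) << "\n";  // 1
  std::cout << MapTensor(3, 10, &tensors_id, &tensors_id_map) << "\n";  // 0 again
  return 0;
}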
*/ +#include "tools/converter/parser/tflite/tflite_one_hot_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_one_hot_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteOneHotParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,17 +40,15 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteOneHotParser"; std::unique_ptr attr(new schema::OneHotT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsOneHotOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsOneHotOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; } - auto axis = tflite_attr->axis; - const auto &tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &tensor = tflite_tensors[tflite_op->inputs[0]]; if (tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; @@ -58,6 +61,13 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_OneHot; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h index f21659714a..8a23110957 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
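Note: GetTfliteData, declared in tflite_node_parser.h above, is how the Pad, Reduce, Reverse, Slice and SpaceToBatchND hunks below read small constant tensors (paddings, axes, begin/size, block shapes) out of the flat TFLite buffer list. A simplified sketch of the idea for the int32 case only; the real template switches over every tflite::TensorType and converts element by element:

#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative only: pull shape-many int32 values out of a raw TFLite buffer,
// the way GetTfliteData does for attr->paddings, attr->axes, attr->begin, etc.
std::vector<int32_t> ReadInt32Tensor(const std::vector<int32_t> &shape,
                                     const std::vector<uint8_t> &buffer) {
  int32_t count = 1;
  for (int32_t dim : shape) {
    count *= dim;                    // element count = product of the dims
  }
  std::vector<int32_t> values(count);
  // The TFLite buffer holds raw little-endian bytes; copy them out as int32.
  std::memcpy(values.data(), buffer.data(), count * sizeof(int32_t));
  return values;
}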
*/ -#ifndef PREDICT_TFLITE_ONE_HOT_PARSER_H -#define PREDICT_TFLITE_ONE_HOT_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ONE_HOT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ONE_HOT_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteOneHotParser : public TfliteNodeParser { public: TfliteOneHotParser() : TfliteNodeParser("OneHot") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ONE_HOT_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ONE_HOT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc index cfd1700547..8f4535b745 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_pad_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_pad_parser.h" +#include namespace mindspore { namespace lite { -STATUS TflitePadParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TflitePadParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,9 +40,9 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tfliteOp return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TflitePadParser"; std::unique_ptr attr(new schema::PadT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsPadOptions(); + + const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -45,13 +50,18 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tfliteOp attr->paddingMode = schema::PaddingMode_CONSTANT; attr->constantValue = 0.0f; - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->paddings)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) { MS_LOG(ERROR) << "get pad -> paddings failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_Pad; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, 
tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h index e2f0c29c7b..44f657ad4e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_PAD_PARSER_H -#define PREDICT_TFLITE_PAD_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PAD_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PAD_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,13 +29,15 @@ class TflitePadParser : public TfliteNodeParser { public: TflitePadParser() : TfliteNodeParser("Pad") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_PAD_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PAD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc index f2c51fcd51..1d7db44adb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc @@ -14,18 +14,21 @@ * limitations under the License. 
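Note: the pooling hunk that follows drops the inline SAME-padding math in favor of a shared getPaddingParam helper (which presumably also covers the VALID case by returning zero pads). For reference, the arithmetic visible in the removed lines is the usual TFLite SAME rule, roughly:

#include <cmath>
#include <utility>

// Illustrative only: pad_before / pad_after for one spatial dimension under SAME
// padding; the removed pooling code did this inline for H and W.
std::pair<int, int> SamePad(int input, int stride, int window) {
  int output = static_cast<int>(std::ceil(static_cast<double>(input) / stride));
  int pad_needed = (output - 1) * stride + window - input;
  if (pad_needed < 0) {
    pad_needed = 0;  // clamp added for safety; the removed inline code had no clamp
  }
  int pad_before = pad_needed / 2;          // floor(pad_needed / 2.0) in the original
  int pad_after = pad_needed - pad_before;
  return {pad_before, pad_after};
}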
*/ +#include "tools/converter/parser/tflite/tflite_pooling_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_pooling_parser.h" +#include namespace mindspore { namespace lite { STATUS TflitePoolingParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -39,7 +42,7 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr &tfli std::unique_ptr attr(new schema::PoolingT()); std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "MeanPooling") == 0) { MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser"; @@ -47,9 +50,6 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr &tfli } else if (std::strcmp(node_name, "MaxPooling") == 0) { MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser"; attr->poolingMode = schema::PoolMode_MAX_POOLING; - } else { - MS_LOG(ERROR) << "wrong pooling type"; - return RET_ERROR; } const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); @@ -64,41 +64,31 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr &tfli attr->padMode = GetPadMode(tflite_attr->padding); attr->format = schema::Format_NHWC; - // by default attr->global = false; attr->roundMode = schema::RoundMode_FLOOR; // calculate pad params - if (attr->padMode == schema::PadMode_VALID || attr->padMode == schema::PadMode_NOTSET) { - attr->padUp = 0; - attr->padDown = 0; - attr->padLeft = 0; - attr->padRight = 0; - } else if (attr->padMode == schema::PadMode_SAME) { - auto data_index = tflite_op->inputs[0]; - const auto &data_tensor = tfliteTensors[data_index]; - if (data_tensor == nullptr) { - MS_LOG(ERROR) << "the first input is null"; - return RET_NULL_PTR; - } - - auto shape = data_tensor->shape; - int H_input = shape.at(1); - int W_input = shape.at(2); - - int H_output = ceil(H_input / attr->strideH); - int pad_needed_H = (H_output - 1) * attr->strideH + attr->windowH - H_input; - attr->padUp = floor(pad_needed_H / 2.0); - attr->padDown = pad_needed_H - attr->padUp; - - int W_output = ceil(W_input / attr->strideW); - int pad_needed_W = (W_output - 1) * attr->strideW + attr->windowW - W_input; - attr->padLeft = floor(pad_needed_W / 2.0); - attr->padRight = pad_needed_W - attr->padLeft; + auto data_index = tflite_op->inputs[0]; + const auto &data_tensor = tflite_tensors[data_index]; + std::vector params; + if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, + attr->strideW, attr->windowH, attr->windowW, ¶ms) != RET_OK) { + MS_LOG(ERROR) << "get padding params failed"; + return RET_ERROR; + } else { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); } op->primitive->value.type = schema::PrimitiveType_Pooling; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], 
tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h index 1129df01cd..fe8d5fa804 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_MEAN_POOLING_PARSER_H -#define PREDICT_TFLITE_MEAN_POOLING_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MEAN_POOLING_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MEAN_POOLING_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,12 +29,13 @@ class TflitePoolingParser : public TfliteNodeParser { public: TflitePoolingParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteMeanPoolingParser : public TflitePoolingParser { @@ -48,5 +50,5 @@ class TfliteMaxPoolingParser : public TflitePoolingParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc index 38216a911e..519b5c243b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_range_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteRangeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteRangeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteRangeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,13 +40,20 @@ STATUS TfliteRangeParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteRangeParser"; std::unique_ptr attr(new schema::RangeT()); attr->dType = 0; +// attr->start +// attr->limit +// attr->delta op->primitive->value.type = schema::PrimitiveType_Range; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), 
tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h index 2701590151..204b5c8e73 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_RANGE_PARSER_H -#define PREDICT_TFLITE_RANGE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANGE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANGE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteRangeParser : public TfliteNodeParser { public: TfliteRangeParser() : TfliteNodeParser("Range") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_RANGE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANGE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc index d66b5278dc..d4fcac5371 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_rank_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteRankParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteRankParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteRankParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,11 +40,15 @@ STATUS TfliteRankParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteRankParser"; std::unique_ptr attr(new schema::RankT()); op->primitive->value.type = schema::PrimitiveType_Rank; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h index 11257b6f2b..9afd1fcc22 100644 --- 
a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_RANK_PARSER_H -#define PREDICT_TFLITE_RANK_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANK_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANK_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteRankParser : public TfliteNodeParser { public: TfliteRankParser() : TfliteNodeParser("Rank") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_RANK_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANK_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc index 4472430e48..9cfbceed32 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc @@ -14,18 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_reduce_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_reduce_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteReduceParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,8 +41,9 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflit } std::unique_ptr attr(new schema::ReduceT()); + // auto tflite_tensors = tflite_subgraph->tensors; - const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsReducerOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; @@ -46,8 +51,9 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflit attr->keepDims = tflite_attr->keep_dims; std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); + if (std::strcmp(node_name, "ReduceMax") == 0) { MS_LOG(DEBUG) << "parse TfliteReduceMaxParser"; attr->mode = schema::ReduceMode_ReduceMax; @@ -67,18 +73,20 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflit // attr->mode; MS_LOG(ERROR) << "ms-lite haven't 
supported REDUCE_ANY now"; return RET_NOT_FIND_OP; - } else { - MS_LOG(ERROR) << "wrong reduce type"; - return RET_ERROR; } - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axes)) { - MS_LOG(ERROR) << "get reduce_prod -> axes failed"; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->axes)) { + MS_LOG(ERROR) << "get reduce -> axes failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_Reduce; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h index 628243881c..7960143948 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_REDUCE_MAX_PARSER_H -#define PREDICT_TFLITE_REDUCE_MAX_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,12 +29,13 @@ class TfliteReduceParser : public TfliteNodeParser { public: TfliteReduceParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteReduceMaxParser : public TfliteReduceParser { @@ -69,4 +71,4 @@ class TfliteReduceAnyParser : public TfliteReduceParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_REDUCE_MAX_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc index c67b113c37..7bdc2d8318 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_reshape_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_reshape_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteReshapeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,23 +40,22 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteReshapeParser"; std::unique_ptr attr(new schema::ReshapeT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsReshapeOptions(); + const auto &tfliteAttr = tflite_op->builtin_options.AsReshapeOptions(); if (tfliteAttr == nullptr) { - if (tfliteOp->inputs.size() < 2) { - MS_LOG(ERROR) << "expected two input tensors, but got: " << tfliteOp->inputs.size(); + if (tflite_op->inputs.size() < 2) { + MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size(); return RET_ERROR; } - auto shape_tensor_index = tfliteOp->inputs[1]; - const auto & shape_tensor = tfliteTensors[shape_tensor_index]; + auto shape_tensor_index = tflite_op->inputs[1]; + const auto & shape_tensor = tflite_tensors[shape_tensor_index]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->shape)) { - MS_LOG(ERROR) << "get reshape->shape error"; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->shape)) { + MS_LOG(ERROR) << "get reshape -> shape failed"; return RET_ERROR; } } else { @@ -64,6 +68,13 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Reshape; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h index a122d9512f..6ab5fa1db6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
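Note on the reshape hunk above: when the builtin ReshapeOptions are absent, the target shape is read from the op's second input tensor via GetTfliteData; otherwise it is taken from the options, and in both cases every TFLite input is registered through AddOpInput. A tiny sketch of that fallback, with plain vectors standing in for the flatbuffer types (options_shape plays the role of ReshapeOptions::new_shape, empty when the options are missing; shape_input is the decoded inputs[1] tensor):

#include <cstdint>
#include <vector>

// Illustrative only: pick the reshape target shape the way the parser does.
std::vector<int64_t> PickReshapeShape(const std::vector<int64_t> &options_shape,
                                      const std::vector<int64_t> &shape_input) {
  if (options_shape.empty()) {
    return shape_input;   // no builtin options: shape comes from inputs[1] via GetTfliteData
  }
  return options_shape;   // otherwise the shape is carried in the builtin options
}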
*/ -#ifndef PREDICT_TFLITE_RESHAPE_PARSER_H -#define PREDICT_TFLITE_RESHAPE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESHAPE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESHAPE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteReshapeParser : public TfliteNodeParser { public: TfliteReshapeParser() : TfliteNodeParser("Reshape") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ADD_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc index 4dfc27c32e..4bb37bb93c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc @@ -14,18 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_resize_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_resize_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteResizeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -41,10 +44,9 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflit std::vector node_name_str; Split(op->name.data(), &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); - if (std::strcmp(node_name, "ResizeBilinear") == 0) { MS_LOG(DEBUG) << "parse TfliteResizeBilinearParser"; - const auto &tfliteAttr = tfliteOp->builtin_options.AsResizeBilinearOptions(); + const auto &tfliteAttr = tflite_op->builtin_options.AsResizeBilinearOptions(); if (tfliteAttr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -53,7 +55,7 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflit attr->method = schema::ResizeMethod_BILINEAR; } else if (std::strcmp(node_name, "NearestNeighbor") == 0) { MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser"; - const auto &tfliteAttr = tfliteOp->builtin_options.AsResizeNearestNeighborOptions(); + const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions(); if (tfliteAttr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -68,14 +70,14 @@ STATUS TfliteResizeParser::Parse(const 
std::unique_ptr &tflit attr->format = schema::Format_NHWC; attr->preserveAspectRatio = false; - auto tfliteResizeTensorIndex = tfliteOp->inputs[1]; - const auto & shape_tensor = tfliteTensors[tfliteResizeTensorIndex]; + auto tfliteResizeTensorIndex = tflite_op->inputs[1]; + const auto & shape_tensor = tflite_tensors[tfliteResizeTensorIndex]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } auto resizeTensorBufferIndex = shape_tensor->buffer; - const auto & buff = tfliteModelBuffer.at(resizeTensorBufferIndex); + const auto & buff = tflite_model_buffer.at(resizeTensorBufferIndex); if (buff == nullptr) { MS_LOG(ERROR) << "buff_data is null"; return RET_NULL_PTR; @@ -88,6 +90,11 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_Resize; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h index 779a1cf0cd..14245ba3e9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_RESIZE_PARSER_H -#define PREDICT_TFLITE_RESIZE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,11 +29,13 @@ class TfliteResizeParser : public TfliteNodeParser { public: TfliteResizeParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteResizeBilinearParser : public TfliteResizeParser { @@ -48,5 +51,5 @@ class TfliteResizeNearestNeighborParser : public TfliteResizeParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_RESIZE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc index a1ddd7f13c..7e449a3588 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_reverse_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteReverseParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> 
&tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteReverseParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteReverseParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,15 +40,20 @@ STATUS TfliteReverseParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteReverseParser"; std::unique_ptr attr(new schema::ReverseT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axis)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->axis)) { + MS_LOG(ERROR) << "get reverse -> axis failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_Reverse; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h index 3965ab1ece..d9fa0ce2df 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_REVERSE_PARSER_H -#define PREDICT_TFLITE_REVERSE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteReverseParser : public TfliteNodeParser { public: TfliteReverseParser() : TfliteNodeParser("reverse") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_REVERSE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc index 867ca6bc77..3d996a4fef 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc @@ -15,17 +15,23 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_reverse_sequence_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_reverse_sequence_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,7 +42,6 @@ STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptr attr(new schema::ReverseSequenceT()); const auto &tflite_attr = tflite_op->builtin_options.AsReverseSequenceOptions(); @@ -54,6 +59,11 @@ STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptrprimitive->value.type = schema::PrimitiveType_ReverseSequence; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h index 20cac753e1..927247fe86 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_REVERSE_SEQUENCE_PARSER_H -#define LITE_TFLITE_REVERSE_SEQUENCE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_SEQUENCE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_SEQUENCE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteReverseSequenceParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_REVERSE_SEQUENCE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_SEQUENCE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc index f1dac84959..e86e7c7a75 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc @@ -14,18 +14,23 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_scatter_nd_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_scatter_nd_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteScatterNdParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteScatterNdParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteScatterNdParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,33 +41,26 @@ STATUS TfliteScatterNdParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteScatterNdParser"; std::unique_ptr attr(new schema::ScatterNDT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsScatterNdOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsScatterNdOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; } - /* - MS_LOG(DEBUG) << "op->inputIndex"; - for (auto &i : op->inputIndex) { - MS_LOG(DEBUG) << i; - } - */ - // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 - // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; - std::swap(op->inputIndex[0], op->inputIndex[2]); - std::swap(op->inputIndex[1], op->inputIndex[2]); - /* - MS_LOG(DEBUG) << "op->inputIndex after resort"; - for (auto &i : op->inputIndex) { - MS_LOG(DEBUG) << i; - } - */ - op->primitive->value.type = schema::PrimitiveType_ScatterND; op->primitive->value.value = attr.release(); + + // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 + // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h index 3823296885..6baeed21be 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
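Note on the ScatterND hunk above: the old in-place std::swap shuffling of op->inputIndex is replaced by registering the inputs explicitly in the order {shape, indices, updates}. The two orderings agree: starting from the TFLite order [indices, updates, shape], swap(0,2) followed by swap(1,2) also yields [shape, indices, updates]. A tiny stand-alone check of that claim:

#include <cassert>
#include <string>
#include <utility>
#include <vector>

int main() {
  // TFLite ScatterNd input order: kIndices = 0, kUpdates = 1, kShape = 2.
  std::vector<std::string> inputs = {"indices", "updates", "shape"};
  // The removed converter code reordered in place:
  std::swap(inputs[0], inputs[2]);
  std::swap(inputs[1], inputs[2]);
  // ... which matches the MS Lite order the new AddOpInput calls produce:
  // kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2.
  assert((inputs == std::vector<std::string>{"shape", "indices", "updates"}));
  return 0;
}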
*/ -#ifndef PREDICT_TFLITE_SCATTER_ND_PARSER_H -#define PREDICT_TFLITE_SCATTER_ND_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SCATTER_ND_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SCATTER_ND_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteScatterNdParser : public TfliteNodeParser { public: TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SCATTER_ND_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SCATTER_ND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc index 95163afde1..eae6cdb352 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_shape_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_shape_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteShapeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteShapeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteShapeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,11 +40,15 @@ STATUS TfliteShapeParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteShapeParser"; std::unique_ptr attr(new schema::ShapeT()); op->primitive->value.type = schema::PrimitiveType_Shape; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h index b0f0fee85c..ab34d3c901 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#ifndef PREDICT_TFLITE_SHAPE_PARSER_H -#define PREDICT_TFLITE_SHAPE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SHAPE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SHAPE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteShapeParser : public TfliteNodeParser { public: TfliteShapeParser() : TfliteNodeParser("Shape") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SHAPE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SHAPE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc index a18e624380..14d3106723 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -17,15 +17,20 @@ #include "tools/converter/parser/tflite/tflite_slice_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteSliceParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSliceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSliceParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,20 +41,26 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteSliceParser"; std::unique_ptr attr(new schema::SliceT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->begin)) { + attr->format = schema::Format_NHWC; + + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->begin)) { MS_LOG(ERROR) << "get slice -> begin failed"; return RET_ERROR; } - if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->size)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->size)) { MS_LOG(ERROR) << "get slice -> size failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_Slice; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h 
b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h index 70c1b96da7..0a84cda642 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_SLICE_PARSER_H -#define PREDICT_TFLITE_SLICE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SLICE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SLICE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteSliceParser : public TfliteNodeParser { public: TfliteSliceParser() : TfliteNodeParser("Slice") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SLICE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SLICE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc index 87277a5872..5c25d4837a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_softmax_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_softmax_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,19 +40,17 @@ STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; std::unique_ptr attr(new schema::SoftMaxT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsSoftmaxOptions(); - if (tflite_attr == nullptr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->axis = -1; op->primitive->value.type = schema::PrimitiveType_SoftMax; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h index 685898c429..73576c110c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
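Note on the softmax hunk above: the AsSoftmaxOptions null-check is dropped and attr->axis is hard-coded to -1. That appears sufficient because TFLite's Softmax always operates over the last dimension, and -1 resolves to that dimension for any rank once normalized:

#include <cassert>

// Illustrative only: normalize a negative axis against the tensor rank.
int NormalizeAxis(int axis, int rank) {
  return axis < 0 ? axis + rank : axis;
}

int main() {
  assert(NormalizeAxis(-1, 4) == 3);  // e.g. the channel axis of an NHWC tensor
  assert(NormalizeAxis(-1, 2) == 1);  // last axis of a [batch, classes] tensor
  return 0;
}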
*/ -#ifndef PREDICT_TFLITE_CONV_PARSER_H -#define PREDICT_TFLITE_CONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteSoftmaxParser : public TfliteNodeParser { public: TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc index 9e3fd7db39..bfd9e8248f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSpaceToBatchNDParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptr attr(new schema::SpaceToBatchNDT()); if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { @@ -51,6 +54,11 @@ STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptrprimitive->value.type = schema::PrimitiveType_SpaceToBatchND; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h index 287f492bc6..e8b5b69a11 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#ifndef LITE_TFLITE_SPACE_TO_BATCH_ND_PARSER_H -#define LITE_TFLITE_SPACE_TO_BATCH_ND_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_BATCH_ND_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_BATCH_ND_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteSpaceToBatchNDParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_SPACE_TO_BATCH_ND_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_BATCH_ND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc index 19d33c1bfa..b9dca6927d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_space_to_depth_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_space_to_depth_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteSpaceToDepthParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteSpaceToDepthParser::Parse(const std::unique_ptr return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; std::unique_ptr attr(new schema::SpaceToDepthT()); const auto &tflite_attr = tflite_op->builtin_options.AsSpaceToDepthOptions(); @@ -46,11 +49,15 @@ STATUS TfliteSpaceToDepthParser::Parse(const std::unique_ptr return RET_NULL_PTR; } attr->blockSize = tflite_attr->block_size; - attr->format = schema::Format_NHWC; op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h index 3adf534253..be2cc7a16c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#ifndef LITE_TFLITE_SPACE_TO_DEPTH_PARSER_H -#define LITE_TFLITE_SPACE_TO_DEPTH_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_DEPTH_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_DEPTH_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteSpaceToDepthParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_SPACE_TO_DEPTH_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_DEPTH_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc index 4a98f90d8f..8859377b96 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; std::unique_ptr attr(new schema::SparseToDenseT()); attr->validateIndices = false; @@ -57,6 +60,11 @@ STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr op->primitive->value.type = schema::PrimitiveType_SparseToDense; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h index ad4a6e02ce..f4c496f5bf 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#ifndef LITE_TFLITE_SPARSE_TO_DENSE_PARSER_H -#define LITE_TFLITE_SPARSE_TO_DENSE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPARSE_TO_DENSE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPARSE_TO_DENSE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteSparseToDenseParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_SPARSE_TO_DENSE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPARSE_TO_DENSE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc index 9e90bbfa03..cf3695eeb1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc @@ -14,19 +14,23 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_split_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_split_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteSplitParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, +STATUS TfliteSplitParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSplitParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,33 +41,32 @@ STATUS TfliteSplitParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteSplitParser"; std::unique_ptr attr(new schema::SplitT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsSplitOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsSplitOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; } auto num_splits = tflite_attr->num_splits; - const auto &shape_tensor = tfliteTensors[tfliteOp->inputs[1]]; + const auto &shape_tensor = tflite_tensors[tflite_op->inputs[1]]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } const auto tensor_shape = shape_tensor->shape; - const auto &axis_tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &axis_tensor = tflite_tensors[tflite_op->inputs[0]]; if (axis_tensor == nullptr) { MS_LOG(ERROR) << "axis_tensor is null"; return RET_NULL_PTR; } - auto axis = *(reinterpret_cast(tfliteModelBuffer[axis_tensor->buffer]->data.data())); + auto axis = *(reinterpret_cast(tflite_model_buffer[axis_tensor->buffer]->data.data())); if (axis < 0) { axis += tensor_shape.size(); } if (axis >= tensor_shape.size()) { - MS_LOG(ERROR) << "axis value too large"; + MS_LOG(ERROR) << 
"axis value is too large"; return RET_ERROR; } attr->splitDim = axis; @@ -79,6 +82,13 @@ STATUS TfliteSplitParser::Parse(const std::unique_ptr &tflite op->primitive->value.type = schema::PrimitiveType_Split; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + for (int i = 0; i < tflite_op->outputs.size(); i++) { + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h index 39210e5086..997fbb01e2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_SPLIT_PARSER_H -#define PREDICT_TFLITE_SPLIT_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteSplitParser : public TfliteNodeParser { public: TfliteSplitParser() : TfliteNodeParser("Split") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SPLIT_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc index 0fdf4b3a3f..f134e085cd 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc @@ -14,17 +14,20 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_split_v_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_split_v_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,44 +38,51 @@ STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteSplitVParser"; + MS_LOG(DEBUG) << "parse TfliteSplitVParser"; std::unique_ptr attr(new schema::SplitT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsSplitVOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsSplitVOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; } attr->numberSplit = tflite_attr->num_splits; - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->sizeSplits)) { - MS_LOG(ERROR) << "get splite_v -> sizeSplits failed"; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->sizeSplits)) { + MS_LOG(ERROR) << "get spliteV -> sizeSplits failed"; return RET_ERROR; } - const auto &tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &tensor = tflite_tensors[tflite_op->inputs[0]]; if (tensor == nullptr) { MS_LOG(ERROR) << "tensor_shape is null"; return RET_NULL_PTR; } auto tensor_shape = tensor->shape; - const auto &axis_tensor = tfliteTensors[tfliteOp->inputs[2]]; + const auto &axis_tensor = tflite_tensors[tflite_op->inputs[2]]; if (axis_tensor == nullptr) { MS_LOG(ERROR) << "axis_tensor is null"; return RET_NULL_PTR; } - auto axis = *(reinterpret_cast(tfliteModelBuffer[axis_tensor->buffer]->data.data())); + auto axis = *(reinterpret_cast(tflite_model_buffer[axis_tensor->buffer]->data.data())); if (axis < 0) { axis += tensor_shape.size(); } if (axis >= tensor_shape.size()) { - MS_LOG(ERROR) << "axis value too large"; + MS_LOG(ERROR) << "axis value is too large"; return RET_ERROR; } attr->splitDim = axis; op->primitive->value.type = schema::PrimitiveType_Split; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + for (int i = 0; i < tflite_op->outputs.size(); i++) { + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h index c3eefcdec2..125ddbc30d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#ifndef PREDICT_TFLITE_SPLIT_V_PARSER_H -#define PREDICT_TFLITE_SPLIT_V_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_V_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_V_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteSplitVParser : public TfliteNodeParser { public: TfliteSplitVParser() : TfliteNodeParser("SplitV") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SPLIT_V_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_V_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc index bd6ee5f641..a5720e4a08 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_squeeze_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_squeeze_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSqueezeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,10 +40,9 @@ STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteSqueezeParser"; std::unique_ptr attr(new schema::SqueezeT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsSqueezeOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsSqueezeOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -47,6 +51,11 @@ STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Squeeze; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h index 
7738773856..e1b9c9436e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_SQUEEZE_PARSER_H -#define PREDICT_TFLITE_SQUEEZE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SQUEEZE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SQUEEZE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteSqueezeParser : public TfliteNodeParser { public: TfliteSqueezeParser() : TfliteNodeParser("Squeeze") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SQUEEZE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SQUEEZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc index b6b7da18b1..63d4f0fbfb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_stack_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_stack_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteStackParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteStackParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteStackParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,21 +40,26 @@ STATUS TfliteStackParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteStackParser"; std::unique_ptr attr(new schema::StackT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsPackOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsPackOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - attr->axis = tflite_attr->axis; attr->n = tflite_attr->values_count; - attr->isScale.assign(tfliteTensors[tfliteOp->inputs[0]]->shape.begin(), - tfliteTensors[tfliteOp->inputs[0]]->shape.end()); + attr->isScale.assign(tflite_tensors[tflite_op->inputs[0]]->shape.begin(), + tflite_tensors[tflite_op->inputs[0]]->shape.end()); op->primitive->value.type = schema::PrimitiveType_Stack; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h index db85b07828..3e6774239a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#ifndef PREDICT_TFLITE_STACK_PARSER_H -#define PREDICT_TFLITE_STACK_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STACK_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STACK_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteStackParser : public TfliteNodeParser { public: TfliteStackParser() : TfliteNodeParser("Stack") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_STACK_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STACK_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc index 04f33beb16..123e458665 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc @@ -14,18 +14,20 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_strided_slice_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_strided_slice_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteStridedSliceParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -43,7 +45,6 @@ STATUS TfliteStridedSliceParser::Parse(const std::unique_ptr MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); return RET_NULL_PTR; } - attr->beginMask = tflite_attr->begin_mask; attr->endMask = tflite_attr->end_mask; attr->ellipsisMask = tflite_attr->ellipsis_mask; @@ -67,6 +68,11 @@ STATUS TfliteStridedSliceParser::Parse(const std::unique_ptr op->primitive->value.type = schema::PrimitiveType_StridedSlice; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h index 4a2b1814db..9e4db2461f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h @@ -15,11 +15,12 @@ */ -#ifndef LITE_TFLITE_STRIDED_SLICE_PARSER_H -#define LITE_TFLITE_STRIDED_SLICE_PARSER_H +#ifndef 
MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STRIDED_SLICE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STRIDED_SLICE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -32,12 +33,13 @@ class TfliteStridedSliceParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_STRIDED_SLICE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STRIDED_SLICE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc index c73477ef3f..6b618829b0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_tile_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_tile_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteTileParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_ return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteTileParser"; std::unique_ptr attr(new schema::TileT()); if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->multiples)) { @@ -47,6 +50,11 @@ STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_ op->primitive->value.type = schema::PrimitiveType_Tile; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h index 48ba12053a..3534901911 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#ifndef LITE_TFLITE_TILE_PARSER_H -#define LITE_TFLITE_TILE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TILE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TILE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteTileParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_TILE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TILE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc index cd55c852b4..a89a641d13 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_topk_v2_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_topk_v2_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteTopKV2Parser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,9 +41,9 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteTopKV2Parser"; std::unique_ptr attr(new schema::TopKV2T()); + attr->sorted = true; if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->k)) { MS_LOG(ERROR) << "get topKV2 -> k failed"; return RET_ERROR; @@ -47,6 +51,11 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_TopKV2; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h index 3bf9c1dabf..6ed92506c1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#ifndef LITE_TFLITE_TOPK_V2_PARSER_H -#define LITE_TFLITE_TOPK_V2_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TOPK_V2_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TOPK_V2_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteTopKV2Parser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_TOPK_V2_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TOPK_V2_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc index 5f3d90b889..759c3d8fc5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -14,17 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_transpose_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_transpose_parser.h" namespace mindspore { namespace lite { -STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteTransposeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,28 +39,23 @@ STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteTransposeParser"; std::unique_ptr attr(new schema::TransposeT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->perm)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->perm)) { MS_LOG(ERROR) << "get transpose -> perm failed"; return RET_ERROR; } - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; - if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; - return RET_ERROR; - } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } - + attr->conjugate = false; op->primitive->value.type = schema::PrimitiveType_Transpose; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpOutput(op, 
tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h index f92eed0c54..4fc4062713 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_TRANSPOSE_PARSER_H -#define PREDICT_TFLITE_TRANSPOSE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TRANSPOSE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TRANSPOSE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteTransposeParser : public TfliteNodeParser { public: TfliteTransposeParser() : TfliteNodeParser("Transpose") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_TRANSPOSE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TRANSPOSE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc index 6860c01bab..aa052c250d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_unique_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_unique_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteUniqueParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteUniqueParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteUniqueParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteUniqueParser"; std::unique_ptr attr(new schema::UniqueT()); const auto &tflite_attr = tflite_op->builtin_options.AsUniqueOptions(); @@ -45,11 +48,17 @@ STATUS TfliteUniqueParser::Parse(const std::unique_ptr &tflit MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); return RET_NULL_PTR; } - - attr->outType = dtype_map[tflite_attr->idx_out_type]; + attr->outType = GetTfliteDataType(tflite_attr->idx_out_type); op->primitive->value.type = schema::PrimitiveType_Unique; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + for (int i = 0; i < tflite_op->outputs.size(); i++) { + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h index 331aa5e48f..2fadd9aa2f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef LITE_TFLITE_UNIQUE_PARSER_H -#define LITE_TFLITE_UNIQUE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNIQUE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNIQUE_PARSER_H #include #include @@ -32,24 +32,12 @@ class TfliteUniqueParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; - - private: - std::map dtype_map = { - {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, - {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, - {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, - {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, - {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, - {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, - {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, - {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, - {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, - }; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_UNIQUE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNIQUE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc index a2ea8b5a6b..b78f4fe799 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc @@ -15,17 +15,23 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_unstack_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_unstack_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteUnstackParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteUnstackParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,7 +42,6 @@ STATUS TfliteUnstackParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(DEBUG) << "paser TfliteUnstackParser"; std::unique_ptr attr(new schema::UnstackT()); const auto &tflite_attr = tflite_op->builtin_options.AsUnpackOptions(); @@ -49,6 +54,13 @@ STATUS TfliteUnstackParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Unstack; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + for (int i = 0; i < tflite_op->outputs.size(); i++) { + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h index 82729e7f38..28fed8b714 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h @@ -14,11 +14,12 @@ * limitations under the License.
*/ -#ifndef LITE_TFLITE_UNSTACK_PARSER_H -#define LITE_TFLITE_UNSTACK_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNSTACK_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNSTACK_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteUnstackParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_UNSTACK_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNSTACK_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index d8358a8e5e..3a4f1b8a9e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -15,25 +15,15 @@ */ #include "tools/converter/parser/tflite/tflite_util.h" -#include #include #include +#include +#include #include "utils/log_adapter.h" #include "include/errorcode.h" namespace mindspore { namespace lite { -std::map tfMsActivationFunctionMap{ - {tflite::ActivationFunctionType_NONE, schema::ActivationType_NO_ACTIVATION}, - {tflite::ActivationFunctionType_RELU, schema::ActivationType_RELU}, - {tflite::ActivationFunctionType_RELU6, schema::ActivationType_RELU6}, - {tflite::ActivationFunctionType_TANH, schema::ActivationType_TANH}, -}; - -schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) { - return tfMsActivationFunctionMap.at(tfliteAFType); -} - std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_CONV_2D, "Conv2D"}, {tflite::BuiltinOperator_DEPTHWISE_CONV_2D, "DepthwiseConv2D"}, @@ -129,25 +119,38 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_UNPACK, "Unstack"}, }; -std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) { - auto iter = tfMsOpTypeMap.find(tfliteOpType); - if (iter == tfMsOpTypeMap.end()) { - // return "unsupported_op_type"; - return tflite::EnumNameBuiltinOperator(tfliteOpType); - } - return iter->second; -} +std::map tfMsActivationFunctionMap{ + {tflite::ActivationFunctionType_NONE, schema::ActivationType_NO_ACTIVATION}, + {tflite::ActivationFunctionType_RELU, schema::ActivationType_RELU}, + {tflite::ActivationFunctionType_RELU6, schema::ActivationType_RELU6}, + {tflite::ActivationFunctionType_TANH, schema::ActivationType_TANH}, +}; std::map type_map = { + {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, - {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, + {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, }; +schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) { + return 
tfMsActivationFunctionMap.at(tfliteAFType); +} + +std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) { + auto iter = tfMsOpTypeMap.find(tfliteOpType); + if (iter == tfMsOpTypeMap.end()) { + // return "unsupported_op_type"; + return tflite::EnumNameBuiltinOperator(tfliteOpType); + } + return iter->second; +} + TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type) { auto iter = type_map.find(tflite_data_type); if (iter == type_map.end()) { @@ -183,12 +186,48 @@ size_t GetDataTypeSize(const TypeId &data_type) { case TypeId::kNumberTypeInt64: return sizeof(int64_t); default: - MS_LOG(ERROR) << data_type; - MS_LOG(ERROR) << "Unsupported datatype"; + MS_LOG(ERROR) << data_type << " is Unsupported datatype"; return RET_ERROR; } } +STATUS getPaddingParam(const std::unique_ptr &tensor, + schema::PadMode pad_mode, + int strideH, int strideW, + int windowH, int windowW, + std::vector *params) { + if (tensor == nullptr) { + MS_LOG(ERROR) << "the input tensor is null"; + return RET_ERROR; + } + + int padUp = 0; + int padDown = 0; + int padLeft = 0; + int padRight = 0; + if (pad_mode == schema::PadMode_SAME) { + auto shape = tensor->shape; + int H_input = shape.at(1); + int W_input = shape.at(2); + + int H_output = ceil(H_input * 1.0 / strideH); + int pad_needed_H = (H_output - 1) * strideH + windowH - H_input; + padUp = floor(pad_needed_H / 2.0); + padDown = pad_needed_H - padUp; + + int W_output = ceil(W_input * 1.0 / strideW); + int pad_needed_W = (W_output - 1) * strideW + windowW - W_input; + padLeft = floor(pad_needed_W / 2.0); + padRight = pad_needed_W - padLeft; + } + + params->emplace_back(padUp); + params->emplace_back(padDown); + params->emplace_back(padLeft); + params->emplace_back(padRight); + return RET_OK; +} + void Split(const std::string &src_str, std::vector *dst_str, const std::string &chr) { std::string ::size_type p1 = 0, p2 = src_str.find(chr); while (std::string::npos != p2) { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h index ee78e96a36..9dc0bba97d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h @@ -19,11 +19,14 @@ #include #include +#include +#include #include "utils/log_adapter.h" #include "schema/inner/model_generated.h" #include "tools/converter/parser/tflite/schema_generated.h" #include "schema/inner/ops_generated.h" #include "ir/dtype/type_id.h" +#include "include/errorcode.h" namespace mindspore { namespace lite { @@ -37,7 +40,15 @@ std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType); TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type); -void Split(const std::string &src_str, std::vector *dst_str, const std::string &chr); +STATUS getPaddingParam(const std::unique_ptr &tensor, + schema::PadMode pad_mode, + int strideH, int strideW, + int windowH, int windowW, + std::vector *params); + +void Split(const std::string &src_str, + std::vector *dst_str, + const std::string &chr); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc index d13b35a6d5..88d027e41a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_where_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_where_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteWhereParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteWhereParser"; std::unique_ptr attr(new schema::WhereT()); if (GetTfliteData(tflite_op->inputs[0], tflite_tensors, tflite_model_buffer, attr->condition)) { @@ -47,6 +50,13 @@ STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite op->primitive->value.type = schema::PrimitiveType_Where; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h index 3a707d0a8b..583b8dffe6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_WHERE_PARSER_H -#define LITE_TFLITE_WHERE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_WHERE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_WHERE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteWhereParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_WHERE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_WHERE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc index fb88ba965b..d5c1750d59 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. 
*/ +#include "tools/converter/parser/tflite/tflite_zeros_like_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_zeros_like_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteZerosLikeParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteZerosLikeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,11 +41,15 @@ STATUS TfliteZerosLikeParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteZerosLikeParser"; std::unique_ptr attr(new schema::ZerosLikeT()); op->primitive->value.type = schema::PrimitiveType_ZerosLike; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h index 0c656137d5..a8cec073fd 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_ZEROS_LIKE_PARSER_H -#define LITE_TFLITE_ZEROS_LIKE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ZEROS_LIKE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ZEROS_LIKE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteZerosLikeParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_ZEROS_LIKE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ZEROS_LIKE_PARSER_H
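
Editor's note: every parser touched by this patch moves from the old Parse() signature (tfliteOpSet / TensorCache / quantizedModel) to a new one that receives the converter's tensor bookkeeping directly and registers its inputs and outputs through AddOpInput()/AddOpOutput(). The angle-bracket contents were stripped from this copy of the diff, so the sketch below restores the template arguments as an assumption (tflite::OperatorT, tflite::TensorT, tflite::BufferT, int32_t, schema::Format, std::map<int, int>) inferred from the surrounding parameter names; TfliteFooParser, schema::FooT, and PrimitiveType_Foo are hypothetical placeholders, not names from the patch.

// Hedged sketch of the new parser interface used throughout this patch.
// Assumptions: the template arguments are inferred from parameter names, and
// TfliteFooParser / schema::FooT / PrimitiveType_Foo stand in for a concrete op.
#include <map>
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"

namespace mindspore {
namespace lite {
STATUS TfliteFooParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
                              const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
                              const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
                              schema::CNodeT *op,
                              std::vector<int32_t> *tensors_id,
                              std::vector<schema::Format> *tensors_format,
                              std::map<int, int> *tensors_id_map) {
  MS_LOG(DEBUG) << "parse TfliteFooParser";
  if (op == nullptr) {
    MS_LOG(ERROR) << "op is null";
    return RET_NULL_PTR;
  }
  if (op->primitive == nullptr) {
    MS_LOG(ERROR) << "op->primitive is null";
    return RET_NULL_PTR;
  }

  std::unique_ptr<schema::FooT> attr(new schema::FooT());
  op->primitive->value.type = schema::PrimitiveType_Foo;
  op->primitive->value.value = attr.release();

  // Instead of writing weights into a TensorCache, the parser now records which TFLite
  // tensor indices feed the node, and in which layout, via the shared tensor-id maps.
  AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
             tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
  AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
              tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore

Ops with several inputs or outputs (Stack, Split, SplitV, Unstack, Unique, Where) apply the same call in a loop over tflite_op->inputs or tflite_op->outputs, as the hunks above show.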
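
Editor's note: tflite_util.cc also gains getPaddingParam(), which derives TensorFlow-style SAME padding from the input shape, stride, and window size. The standalone function below reproduces the same arithmetic so the formula can be checked in isolation; SamePadding() and the example values are illustrative, not part of the patch.

// Standalone check of the SAME-padding arithmetic used by getPaddingParam():
// output = ceil(input / stride), pad_needed = (output - 1) * stride + window - input,
// split between the two sides with the extra pixel going to the bottom/right.
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<int> SamePadding(int h_in, int w_in, int stride_h, int stride_w, int win_h, int win_w) {
  int h_out = static_cast<int>(std::ceil(h_in * 1.0 / stride_h));
  int pad_h = (h_out - 1) * stride_h + win_h - h_in;
  int pad_up = pad_h / 2;  // floor for the non-negative pads produced by SAME mode
  int pad_down = pad_h - pad_up;
  int w_out = static_cast<int>(std::ceil(w_in * 1.0 / stride_w));
  int pad_w = (w_out - 1) * stride_w + win_w - w_in;
  int pad_left = pad_w / 2;
  int pad_right = pad_w - pad_left;
  return {pad_up, pad_down, pad_left, pad_right};  // same order as the patch: up, down, left, right
}

int main() {
  // e.g. a 224x224 input, stride 2, 3x3 window -> pads {0, 1, 0, 1}
  auto p = SamePadding(224, 224, 2, 2, 3, 3);
  std::printf("%d %d %d %d\n", p[0], p[1], p[2], p[3]);
  return 0;
}

The helper in the patch additionally returns all-zero padding for modes other than PadMode_SAME and errors out on a null input tensor; the sketch covers only the SAME-mode arithmetic.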