
refactor tflite parsers

add unit tests

refactor tflite parsers

modify tflite parser, unit tests and test models

supplement the caffe flatten parser

fix the deconv weight tensor format bug

fix a bug when idx = -1

fix the depthwise conv weight tensor format bug
tags/v0.7.0-beta
lyvette committed 5 years ago
commit 123c2024a5
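
For context on the "fix a bug when idx = -1" message: in the TFLite flatbuffer an operator's optional inputs are encoded as tensor index -1, so a parser that indexes the subgraph's tensor vector without checking for it reads past the end. Below is a minimal, hypothetical sketch of the kind of guard such a fix adds; the names tflite_op and tflite_subgraph are placeholders, not the committed code.

// Hypothetical sketch: skip optional inputs, which TFLite marks with index -1.
for (const int32_t idx : tflite_op->inputs) {
  if (idx == -1) {
    continue;  // optional input not present in the model
  }
  const auto &tensor = tflite_subgraph->tensors[idx];  // idx now refers to a real tensor
  // ... convert `tensor` into the output MetaGraphT ...
}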
100 changed files with 1319 additions and 1090 deletions
  1. +14 -14  mindspore/lite/test/run_test.sh
  2. +0 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add.tflite
  3. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite
  4. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite
  5. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite
  6. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite
  7. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite
  8. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite
  9. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite
  10. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite
  11. +0 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div.tflite
  12. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite
  13. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite
  14. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/hardswish.tflite
  15. +0 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logistic.tflite
  16. +0 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mean_pooling.tflite
  17. +0 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul.tflite
  18. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite
  19. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite
  20. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite
  21. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite
  22. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite
  23. +0 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub.tflite
  24. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite
  25. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite
  26. BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite
  27. +59 -12  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc
  28. +2 -4  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc
  29. +44 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc
  30. +2 -3  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc
  31. +39 -202  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc
  32. +5 -7  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc
  33. +3 -6  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc
  34. +41 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc
  35. +57 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc
  36. +58 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc
  37. +3 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc
  38. +92 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc
  39. +3 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc
  40. +2 -3  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc
  41. +2 -3  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc
  42. +2 -3  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc
  43. +2 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc
  44. +2 -4  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc
  45. +2 -10  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc
  46. +10 -30  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc
  47. +2 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc
  48. +4 -8  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc
  49. +2 -4  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc
  50. +5 -7  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc
  51. +45 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc
  52. +3 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc
  53. +5 -7  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc
  54. +3 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc
  55. +8 -10  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc
  56. +5 -7  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc
  57. +5 -7  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc
  58. +44 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc
  59. +13 -15  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc
  60. +3 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc
  61. +4 -7  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc
  62. +43 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc
  63. +3 -4  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc
  64. +3 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc
  65. +2 -2  mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
  66. +8 -3  mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc
  67. +68 -24  mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc
  68. +32 -14  mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h
  69. +20 -8  mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc
  70. +8 -6  mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h
  71. +18 -10  mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc
  72. +11 -9  mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h
  73. +20 -9  mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc
  74. +11 -9  mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h
  75. +84 -136  mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc
  76. +25 -18  mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h
  77. +16 -11  mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc
  78. +8 -5  mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h
  79. +16 -7  mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc
  80. +8 -6  mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h
  81. +19 -12  mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc
  82. +8 -6  mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h
  83. +20 -9  mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc
  84. +11 -9  mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h
  85. +48 -42  mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc
  86. +11 -9  mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h
  87. +1 -0  mindspore/lite/tools/converter/parser/tflite/tflite_converter.h
  88. +44 -18  mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc
  89. +8 -6  mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h
  90. +16 -8  mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc
  91. +8 -6  mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h
  92. +43 -87  mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc
  93. +11 -14  mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h
  94. +22 -19  mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc
  95. +11 -8  mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h
  96. +9 -8  mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc
  97. +11 -9  mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h
  98. +0 -75  mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc
  99. +0 -39  mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h
  100. +19 -12  mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc

+14 -14  mindspore/lite/test/run_test.sh

@@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./
TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
cp -fr $TEST_DATA_DIR/testPK ./data

#./lite-test --gtest_filter="*MindDataTestTensorDE*"
#./lite-test --gtest_filter="*MindDataTestEager*"
#
#./lite-test --gtest_filter="TestTfliteParser*"
#
#./lite-test --gtest_filter="*TestHebing*"
#
#./lite-test --gtest_filter=TestFcFp32*
#./lite-test --gtest_filter=TestConv1x1Fp32*
#./lite-test --gtest_filter=TestStrassenFp32*
#./lite-test --gtest_filter=TestDeConvolutionFp32*
#
#./lite-test --gtest_filter=TestPadInt8.*
#./lite-test --gtest_filter=TestDeconvInt8.*
./lite-test --gtest_filter="*MindDataTestTensorDE*"
./lite-test --gtest_filter="*MindDataTestEager*"
./lite-test --gtest_filter="TestTfliteParser*"
./lite-test --gtest_filter="*TestHebing*"
./lite-test --gtest_filter=TestFcFp32*
./lite-test --gtest_filter=TestConv1x1Fp32*
./lite-test --gtest_filter=TestStrassenFp32*
./lite-test --gtest_filter=TestDeConvolutionFp32*
./lite-test --gtest_filter=TestPadInt8.*
./lite-test --gtest_filter=TestDeconvInt8.*

./lite-test --gtest_filter="TestTfliteParser*"

mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add1.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div1.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/hardswish.tflite
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sigmoid.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logistic.tflite
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/avg_pooling.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mean_pooling.tflite
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul1.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub1.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite
BIN  mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite


+59 -12  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc

@@ -31,6 +31,12 @@ TEST_F(TestTfliteParserRelu, OpType) {
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
}

TEST_F(TestTfliteParserRelu, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_RELU);
}

class TestTfliteParserRelu6 : public TestTfliteParser {
public:
TestTfliteParserRelu6() = default;
@@ -43,6 +49,12 @@ TEST_F(TestTfliteParserRelu6, OpType) {
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
}

TEST_F(TestTfliteParserRelu6, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_RELU6);
}

class TestTfliteParserTanh : public TestTfliteParser {
public:
TestTfliteParserTanh() = default;
@@ -55,7 +67,45 @@ TEST_F(TestTfliteParserTanh, OpType) {
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
}

// logistic
TEST_F(TestTfliteParserTanh, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_TANH);
}

class TestTfliteParserLogistic : public TestTfliteParser {
public:
TestTfliteParserLogistic() = default;
void SetUp() override { meta_graph = LoadAndConvert("./logistic.tflite", ""); }
};

TEST_F(TestTfliteParserLogistic, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
}
TEST_F(TestTfliteParserLogistic, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
}

class TestTfliteParserHardSwish : public TestTfliteParser {
public:
TestTfliteParserHardSwish() = default;
void SetUp() override { meta_graph = LoadAndConvert("./hardswish.tflite", ""); }
};

TEST_F(TestTfliteParserHardSwish, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
}
TEST_F(TestTfliteParserHardSwish, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
}

class TestTfliteParserPrelu : public TestTfliteParser {
public:
@@ -73,12 +123,11 @@ TEST_F(TestTfliteParserPrelu, OpType) {
}

TEST_F(TestTfliteParserPrelu, AttrValue) {
std::vector<float> slope(20, 0);
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope);
auto val = meta_graph->nodes.front()->primitive->value;
std::vector<float> slope(20, 0);
ASSERT_EQ(val.AsPrelu()->slope, slope);
ASSERT_EQ(val.type, schema::PrimitiveType_Prelu);
}

class TestTfliteParserLeakyRelu : public TestTfliteParser {
@@ -94,12 +143,10 @@ TEST_F(TestTfliteParserLeakyRelu, OpType) {
}

TEST_F(TestTfliteParserLeakyRelu, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

auto val = meta_graph->nodes.front()->primitive->value.AsLeakyReLU();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->negativeSlope, 0.20000000298023224);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLeakyReLU(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value;
ASSERT_EQ(val.AsLeakyReLU()->negativeSlope, 0.20000000298023224);
ASSERT_EQ(val.type, schema::PrimitiveType_LeakyReLU);
}

} // namespace mindspore

+2 -4  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc

@@ -35,10 +35,8 @@ TEST_F(TestTfliteParserAddN, OpType) {
}

TEST_F(TestTfliteParserAddN, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4);
auto val = meta_graph->nodes.front()->primitive->value.AsAddN();
ASSERT_EQ(val->N, 4);
}
} // namespace mindspore

+44 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserArgmax : public TestTfliteParser {
public:
TestTfliteParserArgmax() = default;
void SetUp() override { meta_graph = LoadAndConvert("./argmax.tflite", ""); }
};

TEST_F(TestTfliteParserArgmax, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMax) << "wrong Op Type";
}

TEST_F(TestTfliteParserArgmax, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMax(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsArgMax();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->topK, 1);
ASSERT_EQ(val->axisType, 1);
ASSERT_EQ(val->keepDims, false);
ASSERT_EQ(val->outMaxValue, false);
}

} // namespace mindspore
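
Every new test file in this commit follows the same shape as the argmax test above: SetUp() calls LoadAndConvert() on a small .tflite file and the OpType/AttrValue cases assert on the resulting schema::MetaGraphT. As a rough, hedged sketch of what that shared helper from tflite_parsers_test_utils.h presumably does (names and signature are assumptions; the real implementation may differ):

// Hypothetical sketch of the shared test helper used by all fixtures above.
schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const std::string &model_file,
                                                     const std::string &weight_file) {
  lite::TfliteModelParser parser;
  meta_graph = parser.Parse(model_file, weight_file);  // run the converter's TFLite front end
  if (meta_graph == nullptr) {
    MS_LOG(ERROR) << "parse " << model_file << " failed";
  }
  return meta_graph;  // the fixtures then assert on nodes / allTensors of this graph
}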

+2 -3  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc

@@ -25,15 +25,14 @@ class TestTfliteParserArgmin : public TestTfliteParser {
};

TEST_F(TestTfliteParserArgmin, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type";
}

TEST_F(TestTfliteParserArgmin, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMin(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsArgMin();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->topK, 1);


+39 -202  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc

@@ -19,234 +19,57 @@

namespace mindspore {
// doubleInputOp
class TestTfliteParserAdd1 : public TestTfliteParser {
class TestTfliteParserAdd : public TestTfliteParser {
public:
TestTfliteParserAdd1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add1.tflite", ""); }
TestTfliteParserAdd() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add.tflite", ""); }
};

TEST_F(TestTfliteParserAdd1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
}

TEST_F(TestTfliteParserAdd1, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserAdd2 : public TestTfliteParser {
public:
TestTfliteParserAdd2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add2.tflite", ""); }
};

TEST_F(TestTfliteParserAdd2, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
}

TEST_F(TestTfliteParserAdd2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserAdd3 : public TestTfliteParser {
public:
TestTfliteParserAdd3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add3.tflite", ""); }
};

TEST_F(TestTfliteParserAdd3, OpType) {
TEST_F(TestTfliteParserAdd, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
}

TEST_F(TestTfliteParserAdd3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserSub1 : public TestTfliteParser {
public:
TestTfliteParserSub1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub1.tflite", ""); }
};

TEST_F(TestTfliteParserSub1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
}

TEST_F(TestTfliteParserSub1, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserSub2 : public TestTfliteParser {
public:
TestTfliteParserSub2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub2.tflite", ""); }
};

TEST_F(TestTfliteParserSub2, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
}

TEST_F(TestTfliteParserSub2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserSub3 : public TestTfliteParser {
class TestTfliteParserSub : public TestTfliteParser {
public:
TestTfliteParserSub3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub3.tflite", ""); }
TestTfliteParserSub() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub.tflite", ""); }
};

TEST_F(TestTfliteParserSub3, OpType) {
TEST_F(TestTfliteParserSub, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
}

TEST_F(TestTfliteParserSub3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserMul1 : public TestTfliteParser {
public:
TestTfliteParserMul1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul1.tflite", ""); }
};

TEST_F(TestTfliteParserMul1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
}

TEST_F(TestTfliteParserMul1, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserMul2 : public TestTfliteParser {
class TestTfliteParserMul : public TestTfliteParser {
public:
TestTfliteParserMul2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul2.tflite", ""); }
TestTfliteParserMul() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul.tflite", ""); }
};

TEST_F(TestTfliteParserMul2, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
}

TEST_F(TestTfliteParserMul2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserMul3 : public TestTfliteParser {
public:
TestTfliteParserMul3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul3.tflite", ""); }
};

TEST_F(TestTfliteParserMul3, OpType) {
TEST_F(TestTfliteParserMul, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
}

TEST_F(TestTfliteParserMul3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserDiv1 : public TestTfliteParser {
public:
TestTfliteParserDiv1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div1.tflite", ""); }
};

TEST_F(TestTfliteParserDiv1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}

TEST_F(TestTfliteParserDiv1, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserDiv2 : public TestTfliteParser {
class TestTfliteParserDiv : public TestTfliteParser {
public:
TestTfliteParserDiv2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div2.tflite", ""); }
TestTfliteParserDiv() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div.tflite", ""); }
};

TEST_F(TestTfliteParserDiv2, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}

TEST_F(TestTfliteParserDiv2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserDiv3 : public TestTfliteParser {
public:
TestTfliteParserDiv3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div3.tflite", ""); }
};

TEST_F(TestTfliteParserDiv3, OpType) {
TEST_F(TestTfliteParserDiv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}

TEST_F(TestTfliteParserDiv3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserFloorDiv : public TestTfliteParser {
public:
TestTfliteParserFloorDiv() = default;
@@ -254,6 +77,7 @@ class TestTfliteParserFloorDiv : public TestTfliteParser {
};

TEST_F(TestTfliteParserFloorDiv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type";
@@ -266,12 +90,26 @@ class TestTfliteParserFloorMod : public TestTfliteParser {
};

TEST_F(TestTfliteParserFloorMod, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type";
}

// realDiv
class TestTfliteParserRealDiv : public TestTfliteParser {
public:
TestTfliteParserRealDiv() = default;
void SetUp() override {
meta_graph = LoadAndConvert("./realdiv.tflite");
}
};

TEST_F(TestTfliteParserRealDiv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}

class TestTfliteParserSquaredDifference : public TestTfliteParser {
public:
@@ -296,17 +134,15 @@ class TestTfliteParserPow : public TestTfliteParser {
};

TEST_F(TestTfliteParserPow, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type";
}

TEST_F(TestTfliteParserPow, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPower(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPower();

ASSERT_EQ(val->scale, 1.0);
ASSERT_EQ(val->shift, 0.0);
ASSERT_EQ(val->power, 0.0);
@@ -477,6 +313,7 @@ class TestTfliteParserFloor : public TestTfliteParser {
};

TEST_F(TestTfliteParserFloor, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type";


+5 -7  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc

@@ -32,14 +32,12 @@ TEST_F(TestTfliteParserBatchToSpaceNd, OpType) {
}

TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) {
const std::vector<int> blockShape{2, 2};
const std::vector<int> crops{0, 0, 2, 0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops);
auto val = meta_graph->nodes.front()->primitive->value.AsBatchToSpace();
const std::vector<int> blockShape = {2, 2};
ASSERT_EQ(val->blockShape, blockShape);
const std::vector<int> crops = {0, 0, 2, 0};
ASSERT_EQ(val->crops, crops);
}

} // namespace mindspore

+3 -6  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc

@@ -35,12 +35,9 @@ TEST_F(TestTfliteParserCast, OpType) {
}

TEST_F(TestTfliteParserCast, AttrValue) {
// float32 --> int32
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34);
auto val = meta_graph->nodes.front()->primitive->value.AsCast();
ASSERT_EQ(val->srcT, 43);
ASSERT_EQ(val->dstT, 34);
}
} // namespace mindspore

+41 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserConcat : public TestTfliteParser {
public:
TestTfliteParserConcat() = default;
void SetUp() override { meta_graph = LoadAndConvert("./concat.tflite", ""); }
};

TEST_F(TestTfliteParserConcat, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Concat) << "wrong Op Type";
}

TEST_F(TestTfliteParserConcat, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsConcat();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->n, 2);
}

} // namespace mindspore

+57 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc

@@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserConv : public TestTfliteParser {
public:
TestTfliteParserConv() = default;
void SetUp() override { meta_graph = LoadAndConvert("./conv.tflite", ""); }
};

TEST_F(TestTfliteParserConv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
}

TEST_F(TestTfliteParserConv, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 1);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 1);
ASSERT_EQ(val->channelOut, 4);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}

} // namespace mindspore
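
A quick worked check of the padding asserts above: with SAME padding, a 3x3 kernel, stride 1 and dilation 1, the total padding per spatial dimension is (out - 1) * stride + kernel - in = kernel - 1 = 2, which the converter splits evenly into padUp = padDown = padLeft = padRight = 1.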

+58 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc

@@ -0,0 +1,58 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserDeConv : public TestTfliteParser {
public:
TestTfliteParserDeConv() = default;
void SetUp() override { meta_graph = LoadAndConvert("./deconv.tflite", ""); }
};

TEST_F(TestTfliteParserDeConv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type";
}

TEST_F(TestTfliteParserDeConv, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 1);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);

ASSERT_EQ(val->channelIn, 1);
ASSERT_EQ(val->channelOut, 4);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}

} // namespace mindspore
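
The deconv expectations are the ones tied to the "fix the deconv weight tensor format bug" message: TFLite stores transpose-conv filters in a different layout than the runtime kernels consume, so the parser has to record the source layout on the weight tensor for weight_format_pass.cc to transpose it correctly later; tagging the wrong layout is exactly the class of bug the message describes. A hedged illustration only (the layout names and the weight_tensor variable are assumptions, not the committed diff):

// Hypothetical sketch, not the committed code: tag the deconv weight tensor with
// its source layout so the later weight-format pass converts the data correctly.
auto attr = std::make_unique<schema::DeConv2DT>();
attr->format = schema::Format_NHWC;      // activation layout coming from TFLite
// assumption: the TFLite transpose-conv filter layout is CHWK rather than KHWC
weight_tensor->format = schema::Format_CHWK;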

+3 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserDepthToSpace, OpType) {
}

TEST_F(TestTfliteParserDepthToSpace, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC);
auto val = meta_graph->nodes.front()->primitive->value.AsDepthToSpace();
ASSERT_EQ(val->blockSize, 4);
ASSERT_EQ(val->format, schema::Format_NHWC);
}
} // namespace mindspore

+92 -0  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc

@@ -0,0 +1,92 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserDepthwiseConv1 : public TestTfliteParser {
public:
TestTfliteParserDepthwiseConv1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv1.tflite", ""); }
};

TEST_F(TestTfliteParserDepthwiseConv1, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
}

TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 0);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 1);
ASSERT_EQ(val->channelOut, 4);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}

class TestTfliteParserDepthwiseConv2 : public TestTfliteParser {
public:
TestTfliteParserDepthwiseConv2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv2.tflite", ""); }
};

TEST_F(TestTfliteParserDepthwiseConv2, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type";
}

TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 2);
ASSERT_EQ(val->channelMultiplier, 1);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}

} // namespace mindspore

+3 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc

@@ -25,17 +25,15 @@ class TestTfliteParserFill : public TestTfliteParser {
};

TEST_F(TestTfliteParserFill, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type";
}

TEST_F(TestTfliteParserFill, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

TEST_F(TestTfliteParserFill, AttrValue) {;
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsFill(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsFill();

std::vector<int32_t> dims = {9};
ASSERT_EQ(val->dims, dims);
}


+2 -3  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc

@@ -25,15 +25,14 @@ class TestTfliteParserGatherNd : public TestTfliteParser {
};

TEST_F(TestTfliteParserGatherNd, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type";
}

TEST_F(TestTfliteParserGatherNd, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd();
ASSERT_EQ(val->batchDims, 0);
}


+2 -3  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc

@@ -25,15 +25,14 @@ class TestTfliteParserGather : public TestTfliteParser {
};

TEST_F(TestTfliteParserGather, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type";
}

TEST_F(TestTfliteParserGather, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGather(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsGather();
ASSERT_EQ(val->axis, 0);
ASSERT_EQ(val->batchDims, 0);


+2 -3  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc

@@ -25,6 +25,7 @@ class TestTfliteParserLRN : public TestTfliteParser {
};

TEST_F(TestTfliteParserLRN, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type,
@@ -32,9 +33,7 @@ TEST_F(TestTfliteParserLRN, OpType) {
}

TEST_F(TestTfliteParserLRN, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization();
ASSERT_EQ(val->alpha, 1);
ASSERT_EQ(val->beta, 0.5);


+2 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc

@@ -32,12 +32,9 @@ TEST_F(TestTfliteParserOneHot, OpType) {
}

TEST_F(TestTfliteParserOneHot, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr);
// in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size()
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2);
auto val = meta_graph->nodes.front()->primitive->value.AsOneHot();
ASSERT_EQ(val->axis, 2);
}

} // namespace mindspore
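
The removed comment still explains why the remaining assert expects 2: the parser normalizes a negative axis as axis + tensor_shape.size(), so, for example, axis = -1 on a rank-3 input becomes -1 + 3 = 2.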

+2 -4  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc

@@ -25,17 +25,15 @@ class TestTfliteParserPad : public TestTfliteParser {
};

TEST_F(TestTfliteParserPad, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type";
}

TEST_F(TestTfliteParserPad, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPad(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPad();

std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4};
ASSERT_EQ(val->paddings, paddings);
}


+2 -10  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc

@@ -35,12 +35,8 @@ TEST_F(TestTfliteParserMaxPooling, OpType) {
}

TEST_F(TestTfliteParserMaxPooling, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING);
ASSERT_EQ(val->global, false);
@@ -72,12 +68,8 @@ TEST_F(TestTfliteParserAvgPooling, OpType) {
}

TEST_F(TestTfliteParserAvgPooling, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING);
ASSERT_EQ(val->global, false);


+10 -30  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc

@@ -32,13 +32,9 @@ TEST_F(TestTfliteParserReduceMax, OpType) {
}

TEST_F(TestTfliteParserReduceMax, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax);
ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes);
@@ -58,13 +54,9 @@ TEST_F(TestTfliteParserReduceMin, OpType) {
}

TEST_F(TestTfliteParserReduceMin, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin);
ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes);
@@ -84,13 +76,9 @@ TEST_F(TestTfliteParserReduceProd, OpType) {
}

TEST_F(TestTfliteParserReduceProd, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd);
ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes);
@@ -111,13 +99,9 @@ TEST_F(TestTfliteParserSum, OpType) {
}

TEST_F(TestTfliteParserSum, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum);
ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes);
@@ -138,13 +122,9 @@ TEST_F(TestTfliteParserMean, OpType) {
}

TEST_F(TestTfliteParserMean, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean);
ASSERT_EQ(val->keepDims, true);
std::vector<int32_t> axes = {2, 3};
ASSERT_EQ(val->axes, axes);


+2 -5  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc

@@ -35,12 +35,9 @@ TEST_F(TestTfliteParserReshape, OpType) {
}

TEST_F(TestTfliteParserReshape, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReshape();
std::vector<int64_t> shape = {3, 5, 20};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReshape()->shape, shape); // int32
ASSERT_EQ(val->shape, shape);
}
} // namespace mindspore

+4 -8  mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc

@@ -26,17 +26,15 @@ class TestTfliteParserResizeNN : public TestTfliteParser {
};

TEST_F(TestTfliteParserResizeNN, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
}

TEST_F(TestTfliteParserResizeNN, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsResize();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->alignCorners, false);
ASSERT_EQ(val->newHeight, 3);
ASSERT_EQ(val->newWidth, 100);
@@ -52,17 +50,15 @@ class TestTfliteParserResizeBilinear : public TestTfliteParser {
};

TEST_F(TestTfliteParserResizeBilinear, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
}

TEST_F(TestTfliteParserResizeBilinear, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsResize();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->alignCorners, false);
ASSERT_EQ(val->newHeight, 75);
ASSERT_EQ(val->newWidth, 4);


+ 2
- 4
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc

@@ -25,17 +25,15 @@ class TestTfliteParserReverse : public TestTfliteParser {
};

TEST_F(TestTfliteParserReverse, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type";
}

TEST_F(TestTfliteParserReverse, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverse(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReverse();

std::vector<int32_t> axis = {3};
ASSERT_EQ(val->axis, axis);
}


+ 5
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc

@@ -35,13 +35,11 @@ TEST_F(TestTfliteParserReverseSequence, OpType) {
}

TEST_F(TestTfliteParserReverseSequence, AttrValue) {
std::vector<int> seq_length{7, 2, 3, 5};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length);
auto val = meta_graph->nodes.front()->primitive->value.AsReverseSequence();
ASSERT_EQ(val->seqAxis, 1);
ASSERT_EQ(val->seqAxis, 1);
std::vector<int> seq_length = {7, 2, 3, 5};
ASSERT_EQ(val->seqLengths, seq_length);
}
} // namespace mindspore

+ 45
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc

@@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserSlice : public TestTfliteParser {
public:
TestTfliteParserSlice() = default;

void SetUp() override { meta_graph = LoadAndConvert("./slice.tflite"); }
};

TEST_F(TestTfliteParserSlice, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Slice) << "wrong Op Type";
}

TEST_F(TestTfliteParserSlice, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSlice(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsSlice();
ASSERT_EQ(val->format, schema::Format_NHWC);
std::vector<int32_t> begin = {1, 0, 0};
ASSERT_EQ(val->begin, begin);
std::vector<int32_t> size = {1, 1, 3};
ASSERT_EQ(val->size, size);
}

} // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSoftmax, OpType) {
}

TEST_F(TestTfliteParserSoftmax, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSoftMax()->axis, -1);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsSoftMax();
ASSERT_EQ(val->axis, -1);
}

} // namespace mindspore

+ 5
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc

@@ -35,13 +35,11 @@ TEST_F(TestTfliteParserSpaceToBatchND, OpType) {
}

TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) {
std::vector<int> blockshape{2, 2};
std::vector<int> padding{0, 0, 2, 0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding);
auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND();
std::vector<int> blockshape = {2, 2};
ASSERT_EQ(val->blockShape, blockshape);
std::vector<int> padding = {0, 0, 2, 0};
ASSERT_EQ(val->paddings, padding);
}
} // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSpaceToDepth, OpType) {
}

TEST_F(TestTfliteParserSpaceToDepth, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC);
auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToDepth();
ASSERT_EQ(val->blockSize, 2);
ASSERT_EQ(val->format, schema::Format_NHWC);
}
} // namespace mindspore

+ 8
- 10
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc

@@ -35,16 +35,14 @@ TEST_F(TestTfliteParserSparseToDense, OpType) {
}

TEST_F(TestTfliteParserSparseToDense, AttrValue) {
std::vector<int> outputShape{5, 5};
std::vector<int> sparseValue{1};
std::vector<int> defaultValue{0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false);
auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense();
std::vector<int> outputShape = {5, 5};
ASSERT_EQ(val->outputShape, outputShape);
std::vector<int> sparseValue = {1};
ASSERT_EQ(val->sparseValue, sparseValue);
std::vector<int> defaultValue = {0};
ASSERT_EQ(val->defaultValue, defaultValue);
ASSERT_EQ(val->validateIndices, false);
}
} // namespace mindspore

+ 5
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc

@@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplit, OpType) {
}

TEST_F(TestTfliteParserSplit, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
const std::vector<int> sizeSplits{2, 2};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
ASSERT_EQ(val->splitDim, 2);
ASSERT_EQ(val->numberSplit, 2);
const std::vector<int> sizeSplits = {2, 2};
ASSERT_EQ(val->sizeSplits, sizeSplits);
}

} // namespace mindspore

+ 5
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc

@@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplitV, OpType) {
}

TEST_F(TestTfliteParserSplitV, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
const std::vector<int> sizeSplits{1, 3};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
ASSERT_EQ(val->splitDim, 0);
ASSERT_EQ(val->numberSplit, 2);
const std::vector<int> sizeSplits = {1, 3};
ASSERT_EQ(val->sizeSplits, sizeSplits);
}

} // namespace mindspore

+ 44
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserStack : public TestTfliteParser {
public:
TestTfliteParserStack() = default;

void SetUp() override { meta_graph = LoadAndConvert("./stack.tflite"); }
};

TEST_F(TestTfliteParserStack, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Stack) << "wrong Op Type";
}

TEST_F(TestTfliteParserStack, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStack(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsStack();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->n, 2);
const std::vector<int> isScale = {3, 2, 3};
ASSERT_EQ(val->isScale, isScale);
}

} // namespace mindspore

+ 13
- 15
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc

@@ -35,21 +35,19 @@ TEST_F(TestTfliteParserStridedSlice, OpType) {
}

TEST_F(TestTfliteParserStridedSlice, AttrValue) {
std::vector<int> begin{1, -1, 0};
std::vector<int> end{2, -3, 3};
std::vector<int> stride{1, -1, 1};
std::vector<int> isscale{3, 2, 3};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale);
auto val = meta_graph->nodes.front()->primitive->value.AsStridedSlice();
ASSERT_EQ(val->beginMask, 0);
ASSERT_EQ(val->endMask, 0);
ASSERT_EQ(val->beginMask, 0);
ASSERT_EQ(val->beginMask, 0);
std::vector<int> begin = {1, -1, 0};
ASSERT_EQ(val->begin, begin);
std::vector<int> end = {2, -3, 3};
ASSERT_EQ(val->end, end);
std::vector<int> stride = {1, -1, 1};
ASSERT_EQ(val->stride, stride);
std::vector<int> isscale = {3, 2, 3};
ASSERT_EQ(val->isScale, isscale);
}
} // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserTile, OpType) {
}

TEST_F(TestTfliteParserTile, AttrValue) {
std::vector<int> multiply{2, 3, 4};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply);
auto val = meta_graph->nodes.front()->primitive->value.AsTile();
std::vector<int> multiply = {2, 3, 4};
ASSERT_EQ(val->multiples, multiply);
}
} // namespace mindspore

+ 4
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc

@@ -35,13 +35,10 @@ TEST_F(TestTfliteParserTopKV2, OpType) {
}

TEST_F(TestTfliteParserTopKV2, AttrValue) {
// attr->sorted default is true
std::vector<int> k{3};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true);
auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2();
std::vector<int> k = {3};
ASSERT_EQ(val->k, k);
ASSERT_EQ(val->sorted, true);
}
} // namespace mindspore

+ 43
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserTranspose : public TestTfliteParser {
public:
TestTfliteParserTranspose() = default;

void SetUp() override { meta_graph = LoadAndConvert("./transpose.tflite"); }
};

TEST_F(TestTfliteParserTranspose, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
}

TEST_F(TestTfliteParserTranspose, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsTranspose();
ASSERT_EQ(val->conjugate, false);
std::vector<int32_t> perm = {1, 0};
ASSERT_EQ(val->perm, perm);
}

} // namespace mindspore

+ 3
- 4
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc

@@ -35,10 +35,9 @@ TEST_F(TestTfliteParserUnique, OpType) {
}

TEST_F(TestTfliteParserUnique, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32
auto val = meta_graph->nodes.front()->primitive->value.AsUnique();
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
ASSERT_EQ(val->outType, 34);
}
} // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserUnstack, OpType) {
}

TEST_F(TestTfliteParserUnstack, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1);
auto val = meta_graph->nodes.front()->primitive->value.AsUnstack();
ASSERT_EQ(val->num, 5);
ASSERT_EQ(val->axis, 1);
}
} // namespace mindspore

+ 2
- 2
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc

@@ -353,7 +353,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->format == schema::Format_KCHW) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) {
} else if (weightTensor->format == schema::Format_CHWK) { // from tflite
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
@@ -369,7 +369,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
} else if (weightTensor->format == schema::Format_CHWK) { // from tflite
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;


+ 8
- 3
mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc

@@ -21,11 +21,16 @@ namespace lite {
STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
if (op == nullptr) {
// MS_LOGE("null pointer dereferencing.");
// MS_LOG(ERROR) << "null pointer dereferencing.";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT());
attr->format = schema::Format_NCHW;
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
const caffe::FlattenParameter flattenParam = proto.flatten_param();

attr->axis = (int32_t)flattenParam.axis();
attr->useAxis = true;
attr->hasBias = false;
attr->activationType = schema::ActivationType_NO_ACTIVATION;

op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Flatten;


+ 68
- 24
mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc

@@ -14,18 +14,21 @@
* limitations under the License.
*/

#include "tools/converter/parser/tflite/tflite_activation_parser.h"
#include <memory>
#include <vector>
#include <string>
#include "tools/converter/parser/tflite/tflite_activation_parser.h"
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -35,13 +38,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}

std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());

std::vector<std::string> node_name_str;
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();

if (std::strcmp(node_name, "Relu") == 0) {
MS_LOG(DEBUG) << "parse TfliteReluParser";
attr->type = schema::ActivationType_RELU;
@@ -54,29 +55,31 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
} else if (std::strcmp(node_name, "Logistic") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
attr->type = schema::ActivationType_SIGMOID;
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions();
if (option == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_LEAKY_RELU;
attr->alpha = option->alpha;
} else {
MS_LOG(ERROR) << "wrong activation type";
return RET_ERROR;
} else if (std::strcmp(node_name, "HardSwish") == 0) {
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
attr->type = schema::ActivationType_SIGMOID;
}

attr->alpha = 0.2f;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}

STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) {
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TflitePreluParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -86,23 +89,64 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "paser TflitePreluParser";
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());

if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
MS_LOG(ERROR) << "get pRelu -> slope failed";
return RET_ERROR;
}

op->primitive->value.type = schema::PrimitiveType_Prelu;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}

STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}

std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT());

const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->negativeSlope = tflite_attr->alpha;

op->primitive->value.type = schema::PrimitiveType_LeakyReLU;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}

TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());


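Note: the central change in the parser refactor above is the new TfliteNodeParser::Parse signature, which drops the TensorCache / quantizedModel arguments and instead records tensor wiring through AddOpInput / AddOpOutput into caller-owned containers. The fragment below is a minimal caller-side sketch of how the refactored parsers could be driven; GetOpName and GetNodeParser are hypothetical placeholders for the converter's real name derivation and registry lookup, and this surrounding model-parser code is not part of this commit.

// Sketch only: assumed usage of the new Parse signature shown in the diff above.
std::vector<int32_t> tensors_id;             // meta-graph tensor indices, filled via AddOpInput/AddOpOutput
std::vector<schema::Format> tensors_format;  // per-tensor layout, NHWC for tflite models
std::map<int, int> tensors_id_map;           // tflite tensor index -> meta-graph tensor index
for (const auto &tflite_op : tflite_subgraph->operators) {
  auto op = std::make_unique<schema::CNodeT>();
  op->name = GetOpName(*tflite_op);           // hypothetical helper: "OpType-index" style node name
  auto *parser = GetNodeParser(op->name);     // hypothetical lookup behind the TfliteNodeRegister entries
  auto status = parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(),
                              &tensors_id, &tensors_format, &tensors_id_map);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "parse node " << op->name << " failed";
    return status;
  }
  meta_graph->nodes.emplace_back(std::move(op));
}
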
+ 32
- 14
mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h

@@ -14,13 +14,14 @@
* limitations under the License.
*/

#ifndef PREDICT_TFLITE_RELU_PARSER_H
#define PREDICT_TFLITE_RELU_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H

#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
#include <vector>
#include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

namespace mindspore {
namespace lite {
@@ -29,11 +30,13 @@ class TfliteActivationParser : public TfliteNodeParser {
public:
TfliteActivationParser() : TfliteNodeParser("node_name") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};

class TfliteReluParser : public TfliteActivationParser {
@@ -56,9 +59,9 @@ class TfliteLogisticParser : public TfliteActivationParser {
TfliteLogisticParser() : TfliteActivationParser() {}
};

class TfliteLeakyReluParser : public TfliteActivationParser {
class TfliteHardSwishParser : public TfliteActivationParser {
public:
TfliteLeakyReluParser() : TfliteActivationParser() {}
TfliteHardSwishParser() : TfliteActivationParser() {}
};

class TflitePreluParser : public TfliteNodeParser {
@@ -68,12 +71,27 @@ class TflitePreluParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};

class TfliteLeakyReluParser : public TfliteNodeParser {
public:
TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};

} // namespace lite
} // namespace mindspore

#endif // PREDICT_TFLITE_RELU_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H


+ 20
- 8
mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc

@@ -18,14 +18,20 @@
#include "tools/converter/parser/tflite/tflite_addn_parser.h"
#include <vector>
#include <memory>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteAddNParser";

// set attr
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -36,13 +42,19 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteAddNParser";
std::unique_ptr<schema::AddNT> attr(new schema::AddNT());

attr->N = tfliteTensors.size() - 1;

attr->N = tflite_tensors.size() - 1;
op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release();

// set input
for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}



+ 8
- 6
mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef LITE_TFLITE_ADDN_PARSER_H
#define LITE_TFLITE_ADDN_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H

#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -31,11 +32,12 @@ class TfliteAddNParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_TFLITE_ADDN_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H

+ 18
- 10
mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc

@@ -17,16 +17,19 @@
#include "tools/converter/parser/tflite/tflite_argmax_parser.h"
#include <memory>
#include <vector>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -37,7 +40,6 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());

attr->outMaxValue = false;
@@ -45,9 +47,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->keepDims = false;
attr->axisType = 1;

auto axis_idx = tfliteOp->inputs[1];
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer];
// get axis attr
auto axis_idx = tflite_op->inputs[1];
std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR;
@@ -61,6 +64,11 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit

op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}



+ 11
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef PREDICT_TFLITE_ARGMAX_PARSER_H
#define PREDICT_TFLITE_ARGMAX_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H

#include <vector>
#include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -28,14 +29,15 @@ class TfliteArgmaxParser : public TfliteNodeParser {
public:
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // PREDICT_TFLITE_ARGMAX_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H

+ 20
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc

@@ -17,14 +17,19 @@
#include "tools/converter/parser/tflite/tflite_argmin_parser.h"
#include <memory>
#include <vector>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteArgminParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -35,7 +40,6 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteArgminParser";
std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT());

attr->outMaxValue = false;
@@ -43,9 +47,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->keepDims = false;
attr->axisType = 1;

auto axis_idx = tfliteOp->inputs[1];
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer];
// get axis attr
auto axis_idx = tflite_op->inputs[1];
std::for_each(tflite_tensors[axis_idx]->shape.begin(),
tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR;
@@ -59,6 +65,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit

op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}



+ 11
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef PREDICT_TFLITE_ARGMIN_PARSER_H
#define PREDICT_TFLITE_ARGMIN_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H

#include <vector>
#include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -28,14 +29,15 @@ class TfliteArgminParser : public TfliteNodeParser {
public:
TfliteArgminParser() : TfliteNodeParser("Argmin") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // PREDICT_TFLITE_ARGMIN_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H

+ 84
- 136
mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc

@@ -18,14 +18,17 @@
#include <vector>
#include <memory>
#include <string>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -37,124 +40,72 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
}

std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();

if (std::strcmp(node_name, "Add") == 0
|| std::strcmp(node_name, "Sub") == 0
|| std::strcmp(node_name, "Mul") == 0
|| std::strcmp(node_name, "Div") == 0) {
auto x_index = tfliteOp->inputs[0];
const auto &x_tensor = tfliteTensors[x_index];
if (x_tensor == nullptr) {
MS_LOG(ERROR) << "the first input is null";
if (std::strcmp(node_name, "Add") == 0) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT());
const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
if (x_data == nullptr) {
MS_LOG(ERROR) << "the data of the first input is null";
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Sub") == 0) {
MS_LOG(DEBUG) << "parse TfliteSubParser";
std::unique_ptr<schema::SubT> attr(new schema::SubT());
const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
if (!x_data->data.empty()) {
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse the first tensor failed";
return RET_ERROR;
}
}

auto y_index = tfliteOp->inputs[1];
const auto &y_tensor = tfliteTensors[y_index];
if (y_tensor == nullptr) {
MS_LOG(ERROR) << "the second input is null";
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Mul") == 0) {
MS_LOG(DEBUG) << "parse TfliteMulParser";
std::unique_ptr<schema::MulT> attr(new schema::MulT());
const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
if (y_data == nullptr) {
MS_LOG(ERROR) << "the data of the second input is null";
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Div") == 0) {
MS_LOG(DEBUG) << "parse TfliteDivParser";
std::unique_ptr<schema::DivT> attr(new schema::DivT());
const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
if (!y_data->data.empty()) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}
}

if (std::strcmp(node_name, "Add") == 0) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sub") == 0) {
MS_LOG(DEBUG) << "parse TfliteSubParser";
std::unique_ptr<schema::SubT> attr(new schema::SubT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Mul") == 0) {
MS_LOG(DEBUG) << "parse TfliteMulParser";
std::unique_ptr<schema::MulT> attr(new schema::MulT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Div") == 0) {
MS_LOG(DEBUG) << "parse TfliteDivParser";
std::unique_ptr<schema::DivT> attr(new schema::DivT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
return RET_OK;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "FloorDiv") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT());
op->primitive->value.type = schema::PrimitiveType_FloorDiv;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "FloorMod") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorModParser";
std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT());
op->primitive->value.type = schema::PrimitiveType_FloorMod;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "RealDiv") == 0) {
MS_LOG(DEBUG) << "parse TfliteRealDivParser";
std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT());
op->primitive->value.type = schema::PrimitiveType_RealDiv;
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "SquaredDifference") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT());
op->primitive->value.type = schema::PrimitiveType_SquaredDifference;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Pow") == 0) {
MS_LOG(DEBUG) << "parse TflitePowParser";
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
@@ -163,31 +114,35 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
attr->shift = 0.0f;
op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Maximum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMaximumParser";
std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT());
op->primitive->value.type = schema::PrimitiveType_Maximum;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Minimum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMinimumParser";
std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT());
op->primitive->value.type = schema::PrimitiveType_Minimum;
op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
}

// set input
for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}

STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -199,85 +154,79 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
}

std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Abs") == 0) {
MS_LOG(DEBUG) << "parse TfliteAbsParser";
std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Exp") == 0) {
MS_LOG(DEBUG) << "parse TfliteExpParser";
std::unique_ptr<schema::ExpT> attr(new schema::ExpT());
op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteSqrtParser";
std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT());
op->primitive->value.type = schema::PrimitiveType_Sqrt;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Rsqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT());
op->primitive->value.type = schema::PrimitiveType_Rsqrt;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Square") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquareParser";
std::unique_ptr<schema::SquareT> attr(new schema::SquareT());
op->primitive->value.type = schema::PrimitiveType_Square;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sin") == 0) {
MS_LOG(DEBUG) << "parse TfliteSinParser";
std::unique_ptr<schema::SinT> attr(new schema::SinT());
op->primitive->value.type = schema::PrimitiveType_Sin;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Cos") == 0) {
MS_LOG(DEBUG) << "parse TfliteCosParser";
std::unique_ptr<schema::CosT> attr(new schema::CosT());
op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Log") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogParser";
std::unique_ptr<schema::LogT> attr(new schema::LogT());
op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Round") == 0) {
MS_LOG(DEBUG) << "parse TfliteRoundParser";
std::unique_ptr<schema::RoundT> attr(new schema::RoundT());
op->primitive->value.type = schema::PrimitiveType_Round;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Ceil") == 0) {
MS_LOG(DEBUG) << "parse TfliteCeilParser";
std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "flOOR") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorParser";
std::unique_ptr<schema::FloorT> attr(new schema::FloorT());
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
}

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}

STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -289,48 +238,47 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
}

std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Equal") == 0) {
MS_LOG(DEBUG) << "parse TfliteEqualParser";
std::unique_ptr<schema::EqualT> attr(new schema::EqualT());
op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "NotEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT());
op->primitive->value.type = schema::PrimitiveType_NotEqual;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Greater") == 0) {
MS_LOG(DEBUG) << "parse TfliteGreaterParser";
std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT());
op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "GreaterEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT());
op->primitive->value.type = schema::PrimitiveType_GreaterEqual;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Less") == 0) {
MS_LOG(DEBUG) << "parse TfliteLessParser";
std::unique_ptr<schema::LessT> attr(new schema::LessT());
op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "LessEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT());
op->primitive->value.type = schema::PrimitiveType_LessEqual;
op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
}

for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}

TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser());
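
TfliteNodeRegister instances like the one above are what tie an op-type string to its parser; the registry itself is outside this diff, but it presumably amounts to a global name-to-parser map along these lines (class names and layout here are illustrative, not the project's actual code):

#include <map>
#include <string>

class TfliteNodeParser;  // base class of TfliteAddParser, TfliteConvParser, ...

// Hypothetical sketch of the registry that TfliteNodeRegister presumably feeds.
class TfliteNodeParserRegistrySketch {
 public:
  static TfliteNodeParserRegistrySketch *GetInstance() {
    static TfliteNodeParserRegistrySketch instance;
    return &instance;
  }
  std::map<std::string, TfliteNodeParser *> parsers;
};

class TfliteNodeRegisterSketch {
 public:
  TfliteNodeRegisterSketch(const std::string &name, TfliteNodeParser *parser) {
    TfliteNodeParserRegistrySketch::GetInstance()->parsers[name] = parser;
  }
};
// A static TfliteNodeRegisterSketch("Add", ...) then lets the model parser
// look up the parser for an "Add" node by name.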


+25 -18  mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef PREDICT_TFLITE_MATH_PARSER_H
#define PREDICT_TFLITE_MATH_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H

#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -29,11 +30,13 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
public:
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};

class TfliteAddParser : public TfliteDoubleInputOpParser {
@@ -96,11 +99,13 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
public:
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};

class TfliteAbsParser : public TfliteSingleInputOpParser {
@@ -163,11 +168,13 @@ class TfliteCompareOpParser : public TfliteNodeParser {
public:
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};

class TfliteEqualParser : public TfliteCompareOpParser {
@@ -203,5 +210,5 @@ class TfliteLessEqualParser : public TfliteCompareOpParser {
} // namespace lite
} // namespace mindspore

#endif // PREDICT_TFLITE_MATH_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H


+16 -11  mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc

@@ -19,14 +19,17 @@
#include <vector>
#include <memory>
#include <string>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -38,30 +41,32 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
}

std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "BatchToSpace") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
} else if (std::strcmp(node_name, "BatchToSpaceND") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser";
// in tflite
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
}

std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());

if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) {
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) {
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
return RET_ERROR;
}
if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) {
if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) {
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
return RET_ERROR;
}

op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}
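
For reference, a concrete BatchToSpaceND attribute set matching the shape comment above (illustrative values only, not taken from any model in this diff):

#include <cstdint>
#include <vector>

// blockShape is 1-D with one entry per spatial dim; crops is [spatial_dims_num, 2],
// read row-major into a flat vector by GetTfliteData. With these values a
// [4, 1, 1, 1] NHWC input becomes a [1, 2, 2, 1] output.
std::vector<int32_t> block_shape = {2, 2};   // spatial_dims_num == 2
std::vector<int32_t> crops = {0, 0, 0, 0};   // [[0, 0], [0, 0]], nothing cropped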



+8 -5  mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
#define LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H

#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -31,8 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};

class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
@@ -43,4 +46,4 @@ class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
} // namespace lite
} // namespace mindspore

#endif // LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H

+16 -7  mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc

@@ -18,14 +18,19 @@
#include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h"
#include <vector>
#include <memory>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -36,16 +41,20 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT());

if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dst_shape)) {
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) {
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
return RET_ERROR;
}

op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}



+8 -6  mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef LITE_TFLITE_BROADCAST_TO_PARSER_H
#define LITE_TFLITE_BROADCAST_TO_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H

#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -31,11 +32,12 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_TFLITE_BROADCAST_TO_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H

+19 -12  mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc

@@ -14,18 +14,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/parser/tflite/tflite_cast_parser.h"
#include <vector>
#include <memory>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteCastParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -36,25 +40,28 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteCastParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT());

const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]];
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR;
}
attr->srcT = dtype_map[in_tensor->type];

const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR;
}
attr->dstT = dtype_map[out_tensor->type];
attr->dstT = GetTfliteDataType(out_tensor->type);

op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}
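
GetTfliteDataType is not shown in this diff; it presumably maps the flatbuffer tensor type onto the converter's TypeId for attr->srcT / attr->dstT, roughly as in the sketch below (the function name, include paths, and the fallback value are assumptions):

#include <map>

#include "ir/dtype/type_id.h"   // TypeId, kNumberTypeFloat32, ... (assumed include path)
#include "schema_generated.h"   // tflite::TensorType (assumed include path)

// Hypothetical sketch of the tflite -> converter data-type mapping.
static TypeId GetTfliteDataTypeSketch(const tflite::TensorType &type) {
  static const std::map<tflite::TensorType, TypeId> kTypeMap = {
      {tflite::TensorType_FLOAT32, kNumberTypeFloat32},
      {tflite::TensorType_FLOAT16, kNumberTypeFloat16},
      {tflite::TensorType_INT32, kNumberTypeInt32},
      {tflite::TensorType_INT64, kNumberTypeInt64},
      {tflite::TensorType_UINT8, kNumberTypeUInt8},
      {tflite::TensorType_INT8, kNumberTypeInt8},
      {tflite::TensorType_BOOL, kNumberTypeBool},
  };
  auto iter = kTypeMap.find(type);
  return iter == kTypeMap.end() ? kTypeUnknown : iter->second;
}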



+8 -6  mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef LITE_TFLITE_CAST_PARSER_
#define LITE_TFLITE_CAST_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H

#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -31,11 +32,12 @@ class TfliteCastParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_TFLITE_CAST_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H

+20 -9  mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc

@@ -17,14 +17,20 @@
#include "tools/converter/parser/tflite/tflite_concat_parser.h"
#include <vector>
#include <memory>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteConcatParser";

// set attr
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -35,20 +41,25 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteConcatParser";
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());

const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions();
const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions();
if (tfliteAttr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->axis = tfliteAttr->axis;

attr->n = tfliteOp->inputs.size();
attr->n = tflite_op->inputs.size();

op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();

for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}



+11 -9  mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef PREDICT_TFLITE_CONCAT_PARSER_H
#define PREDICT_TFLITE_CONCAT_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H

#include <vector>
#include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -28,15 +29,16 @@ class TfliteConcatParser : public TfliteNodeParser {
public:
TfliteConcatParser() : TfliteNodeParser("Concat") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // PREDICT_TFLITE_CONCAT_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H


+48 -42  mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc

@@ -17,14 +17,19 @@
#include "tools/converter/parser/tflite/tflite_conv_parser.h"
#include <vector>
#include <memory>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteConvParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -35,60 +40,61 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteConvParser";
std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions();
if (tfliteAttr == nullptr) {
const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->group = 1;
attr->strideW = tfliteAttr->stride_w;
attr->strideH = tfliteAttr->stride_h;
attr->dilateH = tfliteAttr->dilation_h_factor;
attr->dilateW = tfliteAttr->dilation_w_factor;
attr->padMode = GetPadMode(tfliteAttr->padding);
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->dilateH = tflite_attr->dilation_h_factor;
attr->dilateW = tflite_attr->dilation_w_factor;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC;
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
attr->hasBias = true;

// get the conv op weight tensor
auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null";
MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[KHWC_C];
attr->channelOut = weight_shape[KHWC_K];
attr->kernelW = weight_shape[KHWC_W];
attr->kernelH = weight_shape[KHWC_H];

// get the conv op bias tensor
if (tfliteOp->inputs.size() == 3) {
attr->hasBias = true;
auto bias_index = tfliteOp->inputs[2];
const auto &bias_tensor = tfliteTensors[bias_index];
if (bias_tensor == nullptr) {
MS_LOG(ERROR) << "bias_tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}
attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[0];
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];

// calculate pad params
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_tensors[data_index];
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
attr->strideW, attr->kernelH, attr->kernelW, &params) != RET_OK) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}

op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}
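
getPaddingParam is also outside this diff; under the usual TFLite convention it would compute explicit pads for SAME padding from the NHWC input shape, with the {up, down, left, right} ordering implied by the assignments above. A minimal sketch under that assumption, using a hypothetical helper name:

#include <algorithm>
#include <vector>

// Hypothetical sketch of SAME-padding math: output extent is ceil(in / stride),
// the total pad is spread as evenly as possible, and the extra pixel goes to
// the down/right side. (Under VALID padding the pads would presumably all be zero.)
static void SamePaddingSketch(int in_h, int in_w, int stride_h, int stride_w,
                              int kernel_h, int kernel_w, std::vector<int> *params) {
  int out_h = (in_h + stride_h - 1) / stride_h;
  int out_w = (in_w + stride_w - 1) / stride_w;
  int pad_h = std::max(0, (out_h - 1) * stride_h + kernel_h - in_h);
  int pad_w = std::max(0, (out_w - 1) * stride_w + kernel_w - in_w);
  *params = {pad_h / 2, pad_h - pad_h / 2,    // padUp, padDown
             pad_w / 2, pad_w - pad_w / 2};   // padLeft, padRight
}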



+11 -9  mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef PREDICT_TFLITE_CONV_PARSER_H
#define PREDICT_TFLITE_CONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H

#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -28,15 +29,16 @@ class TfliteConvParser : public TfliteNodeParser {
public:
TfliteConvParser() : TfliteNodeParser("Conv2D") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // PREDICT_TFLITE_CONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H


+1 -0  mindspore/lite/tools/converter/parser/tflite/tflite_converter.h

@@ -19,6 +19,7 @@

#include <string>
#include <memory>
#include <map>
#include "tools/converter/converter.h"
#include "tools/converter/parser/tflite/tflite_model_parser.h"
#include "tools/converter/graphdef_transform.h"


+44 -18  mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc

@@ -17,14 +17,19 @@
#include "tools/converter/parser/tflite/tflite_deconv_parser.h"
#include <vector>
#include <memory>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -35,11 +40,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT());
const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions();
const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}

@@ -50,26 +54,48 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->dilateW = 1;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC;
attr->activationType = schema::ActivationType_NO_ACTIVATION;
attr->hasBias = true;

// get the conv op weight tensor
auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null";
MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[0];
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];

// calculate pad params
auto data_index = tflite_op->inputs[2];
const auto &data_tensor = tflite_tensors[data_index];
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
attr->strideW, attr->kernelH, attr->kernelW, &params) != RET_OK) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[CHWK_K];
attr->channelOut = weight_shape[CHWK_C];
attr->kernelW = weight_shape[CHWK_W];
attr->kernelH = weight_shape[CHWK_H];

op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}



+8 -6  mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef PREDICT_TFLITE_DECONV_PARSER_H
#define PREDICT_TFLITE_DECONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H

#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -31,11 +32,12 @@ class TfliteDeConvParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // PREDICT_TFLITE_DECONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H

+16 -8  mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc

@@ -18,14 +18,19 @@
#include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h"
#include <vector>
#include <memory>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -36,20 +41,23 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT());

const auto &tflite_attr = tfliteOp->builtin_options.AsDepthToSpaceOptions();
const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
return RET_NULL_PTR;
}
attr->blockSize = tflite_attr->block_size;

attr->format = schema::Format_NHWC;

op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}



+8 -6  mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
#define LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H

#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -31,11 +32,12 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H

+43 -87  mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc

@@ -17,65 +17,22 @@
#include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h"
#include <vector>
#include <memory>
#include <map>
#include "tools/common/node_util.h"

namespace mindspore {
namespace lite {
STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op,
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
const std::unique_ptr<tflite::TensorT> &weightTensor,
TensorCache *tensor_cache) {
std::unique_ptr<schema::Conv2DT> convAttr(new schema::Conv2DT);

convAttr->format = attr->format;
convAttr->channelIn = attr->channelIn;
convAttr->channelOut = attr->channelIn * attr->channelMultiplier;
convAttr->kernelH = attr->kernelH;
convAttr->kernelW = attr->kernelW;
convAttr->strideH = attr->strideH;
convAttr->strideW = attr->strideW;
convAttr->padMode = attr->padMode;
convAttr->padUp = attr->padUp;
convAttr->padDown = attr->padDown;
convAttr->padLeft = attr->padLeft;
convAttr->padRight = attr->padRight;
convAttr->dilateH = attr->dilateH;
convAttr->dilateW = attr->dilateW;
convAttr->hasBias = attr->hasBias;
convAttr->activationType = attr->activationType;

auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name);
if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) {
auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex];
if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) {
// convert weight format KHWC -> CHWK
auto status = TransFilterFormat<uint8_t>(liteWeightTensor, kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
return RET_ERROR;
}
}

if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) {
// convert weight format KHWC -> CHWK
auto status = TransFilterFormat<float>(liteWeightTensor, kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
return RET_ERROR;
}
}
}

op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = convAttr.release();
return RET_OK;
}

STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -86,7 +43,6 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT());
const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions();
if (tflite_attr == nullptr) {
@@ -100,15 +56,20 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC;
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
// get the conv op weight tensor
auto input_index = tflite_op->inputs[0];
const auto &input_tenosr = tflite_tensors[input_index];
if (input_tenosr == nullptr) {
MS_LOG(ERROR) << "the first input is null";
attr->hasBias = true;
attr->channelMultiplier = tflite_attr->depth_multiplier;

// get the data tensor
auto data_index = tflite_op->inputs[1];
const auto &data_tensor = tflite_tensors[data_index];
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "the data tensor is null";
return RET_NULL_PTR;
}
auto input_shape = input_tenosr->shape;
auto data_shape = data_tensor->shape;
attr->channelIn = data_shape[3];

// get the weight tensor
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) {
@@ -116,38 +77,33 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
return RET_NULL_PTR;
}
auto weight_shape = weight_tensor->shape;
attr->channelIn = input_shape[KHWC_C];
attr->channelMultiplier = tflite_attr->depth_multiplier;
attr->kernelH = weight_shape[KHWC_H];
attr->kernelW = weight_shape[KHWC_W];

std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};

if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];

// calculate pad params
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW,
attr->kernelH, attr->kernelW, &params) != RET_OK) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
}

if (tflite_op->inputs.size() == 3) {
attr->hasBias = true;
auto bias_index = tflite_op->inputs[2];
const auto &bias_tensor = tflite_tensors[bias_index];
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}

if (attr->channelMultiplier > 1) {
if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) {
MS_LOG(ERROR) << "Parse Group DepthwiseConv failed";
return RET_ERROR;
}
} else {
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = attr.release();
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}

op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}
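
For the depthwise case the weight layout is the easy thing to get wrong: TFLite packs the per-channel filters as [1, kernel_h, kernel_w, channel_in * depth_multiplier], which is why kernelH/kernelW come from weight-shape indices 1 and 2 while channelIn is taken from the data tensor. A small illustration with made-up shapes (not taken from any model in this diff):

#include <cstdint>

// Illustrative shapes only.
constexpr int32_t kDataShape[4]    = {1, 16, 16, 8};   // NHWC input, channelIn = 8
constexpr int32_t kDepthMultiplier = 2;
constexpr int32_t kWeightShape[4]  = {1, 3, 3, 8 * kDepthMultiplier};  // [1, H, W, C*M]

constexpr int32_t kChannelIn  = kDataShape[3];                  // 8
constexpr int32_t kKernelH    = kWeightShape[1];                // 3
constexpr int32_t kKernelW    = kWeightShape[2];                // 3
constexpr int32_t kChannelOut = kChannelIn * kDepthMultiplier;  // 16 output channels

static_assert(kWeightShape[3] == kChannelOut,
              "TFLite packs channel_in * depth_multiplier into the last weight dim");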



+11 -14  mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H
#define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H

#include <vector>
#include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -28,20 +29,16 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
public:
TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override;

private:
STATUS ParseGroupDepthwiseConv(schema::CNodeT *op,
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
const std::unique_ptr<tflite::TensorT> &weightTensor,
TensorCache *tensor_cache);
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // PREDICT_TFLITE_CONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H


+22 -19  mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc

@@ -16,15 +16,20 @@
#include "tools/converter/parser/tflite/tflite_dequantize_parser.h"
#include <vector>
#include <memory>
#include <map>
#include "tools/common/node_util.h"

namespace mindspore {
namespace lite {
STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -35,32 +40,30 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT);

// get the dequantize input tensor
const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]];
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null";
MS_LOG(ERROR) << "input tensor is null";
return RET_NULL_PTR;
}
attr->srcT = dtype_map[in_tensor->type];

const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
MS_LOG(ERROR) << "output tensor is null";
return RET_NULL_PTR;
}
attr->dstT = dtype_map[out_tensor->type];
std::vector<tflite::TensorT *> weight_tensors{in_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
attr->dstT = GetTfliteDataType(out_tensor->type);

op->primitive->value.type = schema::PrimitiveType_Fp16Cast;
op->primitive->value.value = attr.release();
return 0;

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}

TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());


+11 -8  mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h

@@ -13,11 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H
#define LITE_TFLITE_DEQUANTIZE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H

#include <vector>
#include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -27,13 +28,15 @@ class TfliteDequantizeParser : public TfliteNodeParser {
public:
TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_TFLITE_DEQUANTIZE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H

+9 -8  mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc

@@ -17,16 +17,17 @@
#include "tools/converter/parser/tflite/tflite_expand_dims_parser.h"
#include <vector>
#include <memory>
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -40,7 +41,7 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
MS_LOG(DEBUG) << "parse TfliteExpandDimsParser";
std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT());

const auto &tflite_attr = tfliteOp->builtin_options.AsExpandDimsOptions();
const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;


+11 -9  mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h

@@ -14,11 +14,12 @@
* limitations under the License.
*/

#ifndef PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
#define PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H

#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

@@ -28,15 +29,16 @@ class TfliteExpandDimsParser : public TfliteNodeParser {
public:
TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

#endif // PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H


+0 -75  mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc

@@ -1,75 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_fakequant_parser.h"
#include <vector>
#include <memory>

namespace mindspore {
namespace lite {
STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser";
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());

auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}

if (tfliteOp->inputs.size() == 3) {
attr->hasBias = true;
auto bias_index = tfliteOp->inputs[2];
const auto &bias_tensor = tfliteTensors[bias_index];
if (bias_tensor == nullptr) {
MS_LOG(ERROR) << "bias_tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}
attr->axis = 1;

op->primitive->value.type = schema::PrimitiveType_FullConnection;
op->primitive->value.value = attr.release();
return RET_OK;
}

TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser());
} // namespace lite
} // namespace mindspore

+0 -39  mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h

@@ -1,39 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_FAKEQUANT_PARSER_H
#define LITE_TFLITE_FAKEQUANT_PARSER_H

#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

namespace mindspore {
namespace lite {
class TfliteFakeQuantParser : public TfliteNodeParser {
public:
TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_TFLITE_FAKEQUANT_PARSER_H

+19 -12  mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc

@@ -14,19 +14,22 @@
* limitations under the License.
*/

#include "tools/converter/parser/tflite/tflite_fill_parser.h"
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_fill_parser.h"
#include <map>

namespace mindspore {
namespace lite {
STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteFillParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@@ -37,18 +40,22 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteFillParser";
std::unique_ptr<schema::FillT> attr(new schema::FillT());

if (tfliteOp->inputs.size() > 1) {
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dims)) {
MS_LOG(ERROR) << "get Fill -> dims failed";
if (tflite_op->inputs.size() > 1) {
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) {
MS_LOG(ERROR) << "get fill -> dims failed";
return RET_ERROR;
}
}

op->primitive->value.type = schema::PrimitiveType_Fill;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}



Some files were not shown because too many files changed in this diff
