Browse Source

refactor tflite parsers.

append ut

refactor tflite parsers

modify tflite parser, ut and model

supplement caffe flatten parser

fix the weight tensor format of deconv bug

fix bug when idx=-1

fix the weight tensor format of depthConv bug.
tags/v0.7.0-beta
lyvette 5 years ago
parent
commit
123c2024a5
100 changed files with 1319 additions and 1090 deletions
  1. +14
    -14
      mindspore/lite/test/run_test.sh
  2. +0
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add.tflite
  3. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite
  4. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite
  5. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite
  6. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite
  7. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite
  8. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite
  9. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite
  10. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite
  11. +0
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div.tflite
  12. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite
  13. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite
  14. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/hardswish.tflite
  15. +0
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logistic.tflite
  16. +0
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mean_pooling.tflite
  17. +0
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul.tflite
  18. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite
  19. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite
  20. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite
  21. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite
  22. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite
  23. +0
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub.tflite
  24. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite
  25. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite
  26. BIN
      mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite
  27. +59
    -12
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc
  28. +2
    -4
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc
  29. +44
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc
  30. +2
    -3
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc
  31. +39
    -202
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc
  32. +5
    -7
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc
  33. +3
    -6
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc
  34. +41
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc
  35. +57
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc
  36. +58
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc
  37. +3
    -5
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc
  38. +92
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc
  39. +3
    -5
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc
  40. +2
    -3
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc
  41. +2
    -3
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc
  42. +2
    -3
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc
  43. +2
    -5
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc
  44. +2
    -4
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc
  45. +2
    -10
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc
  46. +10
    -30
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc
  47. +2
    -5
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc
  48. +4
    -8
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc
  49. +2
    -4
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc
  50. +5
    -7
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc
  51. +45
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc
  52. +3
    -5
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc
  53. +5
    -7
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc
  54. +3
    -5
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc
  55. +8
    -10
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc
  56. +5
    -7
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc
  57. +5
    -7
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc
  58. +44
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc
  59. +13
    -15
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc
  60. +3
    -5
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc
  61. +4
    -7
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc
  62. +43
    -0
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc
  63. +3
    -4
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc
  64. +3
    -5
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc
  65. +2
    -2
      mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
  66. +8
    -3
      mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc
  67. +68
    -24
      mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc
  68. +32
    -14
      mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h
  69. +20
    -8
      mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc
  70. +8
    -6
      mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h
  71. +18
    -10
      mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc
  72. +11
    -9
      mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h
  73. +20
    -9
      mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc
  74. +11
    -9
      mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h
  75. +84
    -136
      mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc
  76. +25
    -18
      mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h
  77. +16
    -11
      mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc
  78. +8
    -5
      mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h
  79. +16
    -7
      mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc
  80. +8
    -6
      mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h
  81. +19
    -12
      mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc
  82. +8
    -6
      mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h
  83. +20
    -9
      mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc
  84. +11
    -9
      mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h
  85. +48
    -42
      mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc
  86. +11
    -9
      mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h
  87. +1
    -0
      mindspore/lite/tools/converter/parser/tflite/tflite_converter.h
  88. +44
    -18
      mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc
  89. +8
    -6
      mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h
  90. +16
    -8
      mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc
  91. +8
    -6
      mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h
  92. +43
    -87
      mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc
  93. +11
    -14
      mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h
  94. +22
    -19
      mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc
  95. +11
    -8
      mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h
  96. +9
    -8
      mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc
  97. +11
    -9
      mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h
  98. +0
    -75
      mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc
  99. +0
    -39
      mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h
  100. +19
    -12
      mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc

+ 14
- 14
mindspore/lite/test/run_test.sh View File

@@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./
TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/ TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
cp -fr $TEST_DATA_DIR/testPK ./data cp -fr $TEST_DATA_DIR/testPK ./data


#./lite-test --gtest_filter="*MindDataTestTensorDE*"
#./lite-test --gtest_filter="*MindDataTestEager*"
#
#./lite-test --gtest_filter="TestTfliteParser*"
#
#./lite-test --gtest_filter="*TestHebing*"
#
#./lite-test --gtest_filter=TestFcFp32*
#./lite-test --gtest_filter=TestConv1x1Fp32*
#./lite-test --gtest_filter=TestStrassenFp32*
#./lite-test --gtest_filter=TestDeConvolutionFp32*
#
#./lite-test --gtest_filter=TestPadInt8.*
#./lite-test --gtest_filter=TestDeconvInt8.*
./lite-test --gtest_filter="*MindDataTestTensorDE*"
./lite-test --gtest_filter="*MindDataTestEager*"
./lite-test --gtest_filter="TestTfliteParser*"
./lite-test --gtest_filter="*TestHebing*"
./lite-test --gtest_filter=TestFcFp32*
./lite-test --gtest_filter=TestConv1x1Fp32*
./lite-test --gtest_filter=TestStrassenFp32*
./lite-test --gtest_filter=TestDeConvolutionFp32*
./lite-test --gtest_filter=TestPadInt8.*
./lite-test --gtest_filter=TestDeconvInt8.*


./lite-test --gtest_filter="TestTfliteParser*" ./lite-test --gtest_filter="TestTfliteParser*"

mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add1.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite View File


mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div1.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/hardswish.tflite View File


mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sigmoid.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logistic.tflite View File


mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/avg_pooling.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mean_pooling.tflite View File


mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul1.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite View File


mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub1.tflite → mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite View File


BIN
mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite View File


+ 59
- 12
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc View File

@@ -31,6 +31,12 @@ TEST_F(TestTfliteParserRelu, OpType) {
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
} }


TEST_F(TestTfliteParserRelu, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_RELU);
}

class TestTfliteParserRelu6 : public TestTfliteParser { class TestTfliteParserRelu6 : public TestTfliteParser {
public: public:
TestTfliteParserRelu6() = default; TestTfliteParserRelu6() = default;
@@ -43,6 +49,12 @@ TEST_F(TestTfliteParserRelu6, OpType) {
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
} }


TEST_F(TestTfliteParserRelu6, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_RELU6);
}

class TestTfliteParserTanh : public TestTfliteParser { class TestTfliteParserTanh : public TestTfliteParser {
public: public:
TestTfliteParserTanh() = default; TestTfliteParserTanh() = default;
@@ -55,7 +67,45 @@ TEST_F(TestTfliteParserTanh, OpType) {
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
} }


// logistic
TEST_F(TestTfliteParserTanh, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_TANH);
}

class TestTfliteParserLogistic : public TestTfliteParser {
public:
TestTfliteParserLogistic() = default;
void SetUp() override { meta_graph = LoadAndConvert("./logistic.tflite", ""); }
};

TEST_F(TestTfliteParserLogistic, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
}
TEST_F(TestTfliteParserLogistic, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
}

class TestTfliteParserHardSwish : public TestTfliteParser {
public:
TestTfliteParserHardSwish() = default;
void SetUp() override { meta_graph = LoadAndConvert("./hardswish.tflite", ""); }
};

TEST_F(TestTfliteParserHardSwish, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
}
TEST_F(TestTfliteParserHardSwish, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
}


class TestTfliteParserPrelu : public TestTfliteParser { class TestTfliteParserPrelu : public TestTfliteParser {
public: public:
@@ -73,12 +123,11 @@ TEST_F(TestTfliteParserPrelu, OpType) {
} }


TEST_F(TestTfliteParserPrelu, AttrValue) { TEST_F(TestTfliteParserPrelu, AttrValue) {
std::vector<float> slope(20, 0);
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope);
auto val = meta_graph->nodes.front()->primitive->value;
std::vector<float> slope(20, 0);
ASSERT_EQ(val.AsPrelu()->slope, slope);
ASSERT_EQ(val.type, schema::PrimitiveType_Prelu);
} }


class TestTfliteParserLeakyRelu : public TestTfliteParser { class TestTfliteParserLeakyRelu : public TestTfliteParser {
@@ -94,12 +143,10 @@ TEST_F(TestTfliteParserLeakyRelu, OpType) {
} }


TEST_F(TestTfliteParserLeakyRelu, AttrValue) { TEST_F(TestTfliteParserLeakyRelu, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

auto val = meta_graph->nodes.front()->primitive->value.AsLeakyReLU();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->negativeSlope, 0.20000000298023224);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLeakyReLU(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value;
ASSERT_EQ(val.AsLeakyReLU()->negativeSlope, 0.20000000298023224);
ASSERT_EQ(val.type, schema::PrimitiveType_LeakyReLU);
} }


} // namespace mindspore } // namespace mindspore

+ 2
- 4
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc View File

@@ -35,10 +35,8 @@ TEST_F(TestTfliteParserAddN, OpType) {
} }


TEST_F(TestTfliteParserAddN, AttrValue) { TEST_F(TestTfliteParserAddN, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4);
auto val = meta_graph->nodes.front()->primitive->value.AsAddN();
ASSERT_EQ(val->N, 4);
} }
} // namespace mindspore } // namespace mindspore

+ 44
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc View File

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserArgmax : public TestTfliteParser {
public:
TestTfliteParserArgmax() = default;
void SetUp() override { meta_graph = LoadAndConvert("./argmax.tflite", ""); }
};

TEST_F(TestTfliteParserArgmax, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMax) << "wrong Op Type";
}

TEST_F(TestTfliteParserArgmax, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMax(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsArgMax();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->topK, 1);
ASSERT_EQ(val->axisType, 1);
ASSERT_EQ(val->keepDims, false);
ASSERT_EQ(val->outMaxValue, false);
}

} // namespace mindspore

+ 2
- 3
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc View File

@@ -25,15 +25,14 @@ class TestTfliteParserArgmin : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserArgmin, OpType) { TEST_F(TestTfliteParserArgmin, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type";
} }


TEST_F(TestTfliteParserArgmin, AttrValue) { TEST_F(TestTfliteParserArgmin, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMin(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsArgMin(); auto val = meta_graph->nodes.front()->primitive->value.AsArgMin();
ASSERT_EQ(val->axis, 1); ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->topK, 1); ASSERT_EQ(val->topK, 1);


+ 39
- 202
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc View File

@@ -19,234 +19,57 @@


namespace mindspore { namespace mindspore {
// doubleInputOp // doubleInputOp
class TestTfliteParserAdd1 : public TestTfliteParser {
class TestTfliteParserAdd : public TestTfliteParser {
public: public:
TestTfliteParserAdd1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add1.tflite", ""); }
TestTfliteParserAdd() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add.tflite", ""); }
}; };


TEST_F(TestTfliteParserAdd1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
}

TEST_F(TestTfliteParserAdd1, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserAdd2 : public TestTfliteParser {
public:
TestTfliteParserAdd2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add2.tflite", ""); }
};

TEST_F(TestTfliteParserAdd2, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
}

TEST_F(TestTfliteParserAdd2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserAdd3 : public TestTfliteParser {
public:
TestTfliteParserAdd3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add3.tflite", ""); }
};

TEST_F(TestTfliteParserAdd3, OpType) {
TEST_F(TestTfliteParserAdd, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
} }


TEST_F(TestTfliteParserAdd3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserSub1 : public TestTfliteParser {
public:
TestTfliteParserSub1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub1.tflite", ""); }
};

TEST_F(TestTfliteParserSub1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
}

TEST_F(TestTfliteParserSub1, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserSub2 : public TestTfliteParser {
public:
TestTfliteParserSub2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub2.tflite", ""); }
};

TEST_F(TestTfliteParserSub2, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
}

TEST_F(TestTfliteParserSub2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserSub3 : public TestTfliteParser {
class TestTfliteParserSub : public TestTfliteParser {
public: public:
TestTfliteParserSub3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub3.tflite", ""); }
TestTfliteParserSub() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub.tflite", ""); }
}; };


TEST_F(TestTfliteParserSub3, OpType) {
TEST_F(TestTfliteParserSub, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
} }


TEST_F(TestTfliteParserSub3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserMul1 : public TestTfliteParser {
public:
TestTfliteParserMul1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul1.tflite", ""); }
};

TEST_F(TestTfliteParserMul1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
}

TEST_F(TestTfliteParserMul1, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserMul2 : public TestTfliteParser {
class TestTfliteParserMul : public TestTfliteParser {
public: public:
TestTfliteParserMul2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul2.tflite", ""); }
TestTfliteParserMul() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul.tflite", ""); }
}; };


TEST_F(TestTfliteParserMul2, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
}

TEST_F(TestTfliteParserMul2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserMul3 : public TestTfliteParser {
public:
TestTfliteParserMul3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul3.tflite", ""); }
};

TEST_F(TestTfliteParserMul3, OpType) {
TEST_F(TestTfliteParserMul, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
} }


TEST_F(TestTfliteParserMul3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserDiv1 : public TestTfliteParser {
public:
TestTfliteParserDiv1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div1.tflite", ""); }
};

TEST_F(TestTfliteParserDiv1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}

TEST_F(TestTfliteParserDiv1, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserDiv2 : public TestTfliteParser {
class TestTfliteParserDiv : public TestTfliteParser {
public: public:
TestTfliteParserDiv2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div2.tflite", ""); }
TestTfliteParserDiv() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div.tflite", ""); }
}; };


TEST_F(TestTfliteParserDiv2, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}

TEST_F(TestTfliteParserDiv2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserDiv3 : public TestTfliteParser {
public:
TestTfliteParserDiv3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div3.tflite", ""); }
};

TEST_F(TestTfliteParserDiv3, OpType) {
TEST_F(TestTfliteParserDiv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
} }

TEST_F(TestTfliteParserDiv3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}

class TestTfliteParserFloorDiv : public TestTfliteParser { class TestTfliteParserFloorDiv : public TestTfliteParser {
public: public:
TestTfliteParserFloorDiv() = default; TestTfliteParserFloorDiv() = default;
@@ -254,6 +77,7 @@ class TestTfliteParserFloorDiv : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserFloorDiv, OpType) { TEST_F(TestTfliteParserFloorDiv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type";
@@ -266,12 +90,26 @@ class TestTfliteParserFloorMod : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserFloorMod, OpType) { TEST_F(TestTfliteParserFloorMod, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type";
} }


// realDiv
class TestTfliteParserRealDiv : public TestTfliteParser {
public:
TestTfliteParserRealDiv() = default;
void SetUp() override {
meta_graph = LoadAndConvert("./realdiv.tflite");
}
};

TEST_F(TestTfliteParserRealDiv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}


class TestTfliteParserSquaredDifference : public TestTfliteParser { class TestTfliteParserSquaredDifference : public TestTfliteParser {
public: public:
@@ -296,17 +134,15 @@ class TestTfliteParserPow : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserPow, OpType) { TEST_F(TestTfliteParserPow, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type";
} }


TEST_F(TestTfliteParserPow, AttrValue) { TEST_F(TestTfliteParserPow, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPower(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPower(); auto val = meta_graph->nodes.front()->primitive->value.AsPower();

ASSERT_EQ(val->scale, 1.0); ASSERT_EQ(val->scale, 1.0);
ASSERT_EQ(val->shift, 0.0); ASSERT_EQ(val->shift, 0.0);
ASSERT_EQ(val->power, 0.0); ASSERT_EQ(val->power, 0.0);
@@ -477,6 +313,7 @@ class TestTfliteParserFloor : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserFloor, OpType) { TEST_F(TestTfliteParserFloor, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type";


+ 5
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc View File

@@ -32,14 +32,12 @@ TEST_F(TestTfliteParserBatchToSpaceNd, OpType) {
} }


TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) { TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) {
const std::vector<int> blockShape{2, 2};
const std::vector<int> crops{0, 0, 2, 0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops);
auto val = meta_graph->nodes.front()->primitive->value.AsBatchToSpace();
const std::vector<int> blockShape = {2, 2};
ASSERT_EQ(val->blockShape, blockShape);
const std::vector<int> crops = {0, 0, 2, 0};
ASSERT_EQ(val->crops, crops);
} }


} // namespace mindspore } // namespace mindspore

+ 3
- 6
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc View File

@@ -35,12 +35,9 @@ TEST_F(TestTfliteParserCast, OpType) {
} }


TEST_F(TestTfliteParserCast, AttrValue) { TEST_F(TestTfliteParserCast, AttrValue) {
// float32 --> int32
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34);
auto val = meta_graph->nodes.front()->primitive->value.AsCast();
ASSERT_EQ(val->srcT, 43);
ASSERT_EQ(val->dstT, 34);
} }
} // namespace mindspore } // namespace mindspore

+ 41
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc View File

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserConcat : public TestTfliteParser {
public:
TestTfliteParserConcat() = default;
void SetUp() override { meta_graph = LoadAndConvert("./concat.tflite", ""); }
};

TEST_F(TestTfliteParserConcat, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Concat) << "wrong Op Type";
}

TEST_F(TestTfliteParserConcat, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsConcat();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->n, 2);
}

} // namespace mindspore

+ 57
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc View File

@@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserConv : public TestTfliteParser {
public:
TestTfliteParserConv() = default;
void SetUp() override { meta_graph = LoadAndConvert("./conv.tflite", ""); }
};

TEST_F(TestTfliteParserConv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
}

TEST_F(TestTfliteParserConv, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 1);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 1);
ASSERT_EQ(val->channelOut, 4);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}

} // namespace mindspore

+ 58
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc View File

@@ -0,0 +1,58 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserDeConv : public TestTfliteParser {
public:
TestTfliteParserDeConv() = default;
void SetUp() override { meta_graph = LoadAndConvert("./deconv.tflite", ""); }
};

TEST_F(TestTfliteParserDeConv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type";
}

TEST_F(TestTfliteParserDeConv, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 1);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);

ASSERT_EQ(val->channelIn, 1);
ASSERT_EQ(val->channelOut, 4);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}

} // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc View File

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserDepthToSpace, OpType) {
} }


TEST_F(TestTfliteParserDepthToSpace, AttrValue) { TEST_F(TestTfliteParserDepthToSpace, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC);
auto val = meta_graph->nodes.front()->primitive->value.AsDepthToSpace();
ASSERT_EQ(val->blockSize, 4);
ASSERT_EQ(val->format, schema::Format_NHWC);
} }
} // namespace mindspore } // namespace mindspore

+ 92
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc View File

@@ -0,0 +1,92 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserDepthwiseConv1 : public TestTfliteParser {
public:
TestTfliteParserDepthwiseConv1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv1.tflite", ""); }
};

TEST_F(TestTfliteParserDepthwiseConv1, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
}

TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 0);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 1);
ASSERT_EQ(val->channelOut, 4);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}

class TestTfliteParserDepthwiseConv2 : public TestTfliteParser {
public:
TestTfliteParserDepthwiseConv2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv2.tflite", ""); }
};

TEST_F(TestTfliteParserDepthwiseConv2, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type";
}

TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 2);
ASSERT_EQ(val->channelMultiplier, 1);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}

} // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc View File

@@ -25,17 +25,15 @@ class TestTfliteParserFill : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserFill, OpType) { TEST_F(TestTfliteParserFill, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type";
} }


TEST_F(TestTfliteParserFill, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

TEST_F(TestTfliteParserFill, AttrValue) {;
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsFill(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsFill(); auto val = meta_graph->nodes.front()->primitive->value.AsFill();

std::vector<int32_t> dims = {9}; std::vector<int32_t> dims = {9};
ASSERT_EQ(val->dims, dims); ASSERT_EQ(val->dims, dims);
} }


+ 2
- 3
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc View File

@@ -25,15 +25,14 @@ class TestTfliteParserGatherNd : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserGatherNd, OpType) { TEST_F(TestTfliteParserGatherNd, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type";
} }


TEST_F(TestTfliteParserGatherNd, AttrValue) { TEST_F(TestTfliteParserGatherNd, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd(); auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd();
ASSERT_EQ(val->batchDims, 0); ASSERT_EQ(val->batchDims, 0);
} }


+ 2
- 3
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc View File

@@ -25,15 +25,14 @@ class TestTfliteParserGather : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserGather, OpType) { TEST_F(TestTfliteParserGather, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type";
} }


TEST_F(TestTfliteParserGather, AttrValue) { TEST_F(TestTfliteParserGather, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGather(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsGather(); auto val = meta_graph->nodes.front()->primitive->value.AsGather();
ASSERT_EQ(val->axis, 0); ASSERT_EQ(val->axis, 0);
ASSERT_EQ(val->batchDims, 0); ASSERT_EQ(val->batchDims, 0);


+ 2
- 3
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc View File

@@ -25,6 +25,7 @@ class TestTfliteParserLRN : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserLRN, OpType) { TEST_F(TestTfliteParserLRN, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type,
@@ -32,9 +33,7 @@ TEST_F(TestTfliteParserLRN, OpType) {
} }


TEST_F(TestTfliteParserLRN, AttrValue) { TEST_F(TestTfliteParserLRN, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(); auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization();
ASSERT_EQ(val->alpha, 1); ASSERT_EQ(val->alpha, 1);
ASSERT_EQ(val->beta, 0.5); ASSERT_EQ(val->beta, 0.5);


+ 2
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc View File

@@ -32,12 +32,9 @@ TEST_F(TestTfliteParserOneHot, OpType) {
} }


TEST_F(TestTfliteParserOneHot, AttrValue) { TEST_F(TestTfliteParserOneHot, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr);
// in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size()
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2);
auto val = meta_graph->nodes.front()->primitive->value.AsOneHot();
ASSERT_EQ(val->axis, 2);
} }


} // namespace mindspore } // namespace mindspore

+ 2
- 4
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc View File

@@ -25,17 +25,15 @@ class TestTfliteParserPad : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserPad, OpType) { TEST_F(TestTfliteParserPad, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type";
} }


TEST_F(TestTfliteParserPad, AttrValue) { TEST_F(TestTfliteParserPad, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPad(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPad(); auto val = meta_graph->nodes.front()->primitive->value.AsPad();

std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4}; std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4};
ASSERT_EQ(val->paddings, paddings); ASSERT_EQ(val->paddings, paddings);
} }


+ 2
- 10
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc View File

@@ -35,12 +35,8 @@ TEST_F(TestTfliteParserMaxPooling, OpType) {
} }


TEST_F(TestTfliteParserMaxPooling, AttrValue) { TEST_F(TestTfliteParserMaxPooling, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING); ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING);
ASSERT_EQ(val->global, false); ASSERT_EQ(val->global, false);
@@ -72,12 +68,8 @@ TEST_F(TestTfliteParserAvgPooling, OpType) {
} }


TEST_F(TestTfliteParserAvgPooling, AttrValue) { TEST_F(TestTfliteParserAvgPooling, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING); ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING);
ASSERT_EQ(val->global, false); ASSERT_EQ(val->global, false);


+ 10
- 30
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc View File

@@ -32,13 +32,9 @@ TEST_F(TestTfliteParserReduceMax, OpType) {
} }


TEST_F(TestTfliteParserReduceMax, AttrValue) { TEST_F(TestTfliteParserReduceMax, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax);
ASSERT_EQ(val->keepDims, false); ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2}; std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);
@@ -58,13 +54,9 @@ TEST_F(TestTfliteParserReduceMin, OpType) {
} }


TEST_F(TestTfliteParserReduceMin, AttrValue) { TEST_F(TestTfliteParserReduceMin, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin);
ASSERT_EQ(val->keepDims, false); ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2}; std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);
@@ -84,13 +76,9 @@ TEST_F(TestTfliteParserReduceProd, OpType) {
} }


TEST_F(TestTfliteParserReduceProd, AttrValue) { TEST_F(TestTfliteParserReduceProd, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd);
ASSERT_EQ(val->keepDims, false); ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2}; std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);
@@ -111,13 +99,9 @@ TEST_F(TestTfliteParserSum, OpType) {
} }


TEST_F(TestTfliteParserSum, AttrValue) { TEST_F(TestTfliteParserSum, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum);
ASSERT_EQ(val->keepDims, false); ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2}; std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);
@@ -138,13 +122,9 @@ TEST_F(TestTfliteParserMean, OpType) {
} }


TEST_F(TestTfliteParserMean, AttrValue) { TEST_F(TestTfliteParserMean, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean) << "wrong reduce mode";
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean);
ASSERT_EQ(val->keepDims, true); ASSERT_EQ(val->keepDims, true);
std::vector<int32_t> axes = {2, 3}; std::vector<int32_t> axes = {2, 3};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);


+ 2
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc View File

@@ -35,12 +35,9 @@ TEST_F(TestTfliteParserReshape, OpType) {
} }


TEST_F(TestTfliteParserReshape, AttrValue) { TEST_F(TestTfliteParserReshape, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReshape();
std::vector<int64_t> shape = {3, 5, 20}; std::vector<int64_t> shape = {3, 5, 20};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReshape()->shape, shape); // int32
ASSERT_EQ(val->shape, shape);
} }
} // namespace mindspore } // namespace mindspore

+ 4
- 8
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc View File

@@ -26,17 +26,15 @@ class TestTfliteParserResizeNN : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserResizeNN, OpType) { TEST_F(TestTfliteParserResizeNN, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
} }


TEST_F(TestTfliteParserResizeNN, AttrValue) { TEST_F(TestTfliteParserResizeNN, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsResize(); auto val = meta_graph->nodes.front()->primitive->value.AsResize();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->alignCorners, false); ASSERT_EQ(val->alignCorners, false);
ASSERT_EQ(val->newHeight, 3); ASSERT_EQ(val->newHeight, 3);
ASSERT_EQ(val->newWidth, 100); ASSERT_EQ(val->newWidth, 100);
@@ -52,17 +50,15 @@ class TestTfliteParserResizeBilinear : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserResizeBilinear, OpType) { TEST_F(TestTfliteParserResizeBilinear, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
} }


TEST_F(TestTfliteParserResizeBilinear, AttrValue) { TEST_F(TestTfliteParserResizeBilinear, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsResize(); auto val = meta_graph->nodes.front()->primitive->value.AsResize();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->alignCorners, false); ASSERT_EQ(val->alignCorners, false);
ASSERT_EQ(val->newHeight, 75); ASSERT_EQ(val->newHeight, 75);
ASSERT_EQ(val->newWidth, 4); ASSERT_EQ(val->newWidth, 4);


+ 2
- 4
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc View File

@@ -25,17 +25,15 @@ class TestTfliteParserReverse : public TestTfliteParser {
}; };


TEST_F(TestTfliteParserReverse, OpType) { TEST_F(TestTfliteParserReverse, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type";
} }


TEST_F(TestTfliteParserReverse, AttrValue) { TEST_F(TestTfliteParserReverse, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);

ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverse(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReverse(); auto val = meta_graph->nodes.front()->primitive->value.AsReverse();

std::vector<int32_t> axis = {3}; std::vector<int32_t> axis = {3};
ASSERT_EQ(val->axis, axis); ASSERT_EQ(val->axis, axis);
} }


+ 5
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc View File

@@ -35,13 +35,11 @@ TEST_F(TestTfliteParserReverseSequence, OpType) {
} }


TEST_F(TestTfliteParserReverseSequence, AttrValue) { TEST_F(TestTfliteParserReverseSequence, AttrValue) {
std::vector<int> seq_length{7, 2, 3, 5};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length);
auto val = meta_graph->nodes.front()->primitive->value.AsReverseSequence();
ASSERT_EQ(val->seqAxis, 1);
ASSERT_EQ(val->seqAxis, 1);
std::vector<int> seq_length = {7, 2, 3, 5};
ASSERT_EQ(val->seqLengths, seq_length);
} }
} // namespace mindspore } // namespace mindspore

+ 45
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc View File

@@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserSlice : public TestTfliteParser {
public:
TestTfliteParserSlice() = default;

void SetUp() override { meta_graph = LoadAndConvert("./slice.tflite"); }
};

TEST_F(TestTfliteParserSlice, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Slice) << "wrong Op Type";
}

TEST_F(TestTfliteParserSlice, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSlice(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsSlice();
ASSERT_EQ(val->format, schema::Format_NHWC);
std::vector<int32_t> begin = {1, 0, 0};
ASSERT_EQ(val->begin, begin);
std::vector<int32_t> size = {1, 1, 3};
ASSERT_EQ(val->size, size);
}

} // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc View File

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSoftmax, OpType) {
} }


TEST_F(TestTfliteParserSoftmax, AttrValue) { TEST_F(TestTfliteParserSoftmax, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSoftMax()->axis, -1);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsSoftMax();
ASSERT_EQ(val->axis, -1);
} }


} // namespace mindspore } // namespace mindspore

+ 5
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc View File

@@ -35,13 +35,11 @@ TEST_F(TestTfliteParserSpaceToBatchND, OpType) {
} }


TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) { TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) {
std::vector<int> blockshape{2, 2};
std::vector<int> padding{0, 0, 2, 0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding);
auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND();
std::vector<int> blockshape = {2, 2};
ASSERT_EQ(val->blockShape, blockshape);
std::vector<int> padding = {0, 0, 2, 0};
ASSERT_EQ(val->paddings, padding);
} }
} // namespace mindspore } // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc View File

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSpaceToDepth, OpType) {
} }


TEST_F(TestTfliteParserSpaceToDepth, AttrValue) { TEST_F(TestTfliteParserSpaceToDepth, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC);
auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToDepth();
ASSERT_EQ(val->blockSize, 2);
ASSERT_EQ(val->format, schema::Format_NHWC);
} }
} // namespace mindspore } // namespace mindspore

+ 8
- 10
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc View File

@@ -35,16 +35,14 @@ TEST_F(TestTfliteParserSparseToDense, OpType) {
} }


TEST_F(TestTfliteParserSparseToDense, AttrValue) { TEST_F(TestTfliteParserSparseToDense, AttrValue) {
std::vector<int> outputShape{5, 5};
std::vector<int> sparseValue{1};
std::vector<int> defaultValue{0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false);
auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense();
std::vector<int> outputShape = {5, 5};
ASSERT_EQ(val->outputShape, outputShape);
std::vector<int> sparseValue = {1};
ASSERT_EQ(val->sparseValue, sparseValue);
std::vector<int> defaultValue = {0};
ASSERT_EQ(val->defaultValue, defaultValue);
ASSERT_EQ(val->validateIndices, false);
} }
} // namespace mindspore } // namespace mindspore

+ 5
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc View File

@@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplit, OpType) {
} }


TEST_F(TestTfliteParserSplit, AttrValue) { TEST_F(TestTfliteParserSplit, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
const std::vector<int> sizeSplits{2, 2};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
ASSERT_EQ(val->splitDim, 2);
ASSERT_EQ(val->numberSplit, 2);
const std::vector<int> sizeSplits = {2, 2};
ASSERT_EQ(val->sizeSplits, sizeSplits);
} }


} // namespace mindspore } // namespace mindspore

+ 5
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc View File

@@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplitV, OpType) {
} }


TEST_F(TestTfliteParserSplitV, AttrValue) { TEST_F(TestTfliteParserSplitV, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
const std::vector<int> sizeSplits{1, 3};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
ASSERT_EQ(val->splitDim, 0);
ASSERT_EQ(val->numberSplit, 2);
const std::vector<int> sizeSplits = {1, 3};
ASSERT_EQ(val->sizeSplits, sizeSplits);
} }


} // namespace mindspore } // namespace mindspore

+ 44
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc View File

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserStack : public TestTfliteParser {
public:
TestTfliteParserStack() = default;

void SetUp() override { meta_graph = LoadAndConvert("./stack.tflite"); }
};

TEST_F(TestTfliteParserStack, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Stack) << "wrong Op Type";
}

TEST_F(TestTfliteParserStack, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStack(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsStack();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->n, 2);
const std::vector<int> isScale = {3, 2, 3};
ASSERT_EQ(val->isScale, isScale);
}

} // namespace mindspore

+ 13
- 15
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc View File

@@ -35,21 +35,19 @@ TEST_F(TestTfliteParserStridedSlice, OpType) {
} }


TEST_F(TestTfliteParserStridedSlice, AttrValue) { TEST_F(TestTfliteParserStridedSlice, AttrValue) {
std::vector<int> begin{1, -1, 0};
std::vector<int> end{2, -3, 3};
std::vector<int> stride{1, -1, 1};
std::vector<int> isscale{3, 2, 3};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale);
auto val = meta_graph->nodes.front()->primitive->value.AsStridedSlice();
ASSERT_EQ(val->beginMask, 0);
ASSERT_EQ(val->endMask, 0);
ASSERT_EQ(val->beginMask, 0);
ASSERT_EQ(val->beginMask, 0);
std::vector<int> begin = {1, -1, 0};
ASSERT_EQ(val->begin, begin);
std::vector<int> end = {2, -3, 3};
ASSERT_EQ(val->end, end);
std::vector<int> stride = {1, -1, 1};
ASSERT_EQ(val->stride, stride);
std::vector<int> isscale = {3, 2, 3};
ASSERT_EQ(val->isScale, isscale);
} }
} // namespace mindspore } // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc View File

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserTile, OpType) {
} }


TEST_F(TestTfliteParserTile, AttrValue) { TEST_F(TestTfliteParserTile, AttrValue) {
std::vector<int> multiply{2, 3, 4};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply);
auto val = meta_graph->nodes.front()->primitive->value.AsTile();
std::vector<int> multiply = {2, 3, 4};
ASSERT_EQ(val->multiples, multiply);
} }
} // namespace mindspore } // namespace mindspore

+ 4
- 7
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc View File

@@ -35,13 +35,10 @@ TEST_F(TestTfliteParserTopKV2, OpType) {
} }


TEST_F(TestTfliteParserTopKV2, AttrValue) { TEST_F(TestTfliteParserTopKV2, AttrValue) {
// attr->sorted default is true
std::vector<int> k{3};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true);
auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2();
std::vector<int> k = {3};
ASSERT_EQ(val->k, k);
ASSERT_EQ(val->sorted, true);
} }
} // namespace mindspore } // namespace mindspore

+ 43
- 0
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc View File

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"

namespace mindspore {
class TestTfliteParserTranspose : public TestTfliteParser {
public:
TestTfliteParserTranspose() = default;

void SetUp() override { meta_graph = LoadAndConvert("./transpose.tflite"); }
};

TEST_F(TestTfliteParserTranspose, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
}

TEST_F(TestTfliteParserTranspose, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsTranspose();
ASSERT_EQ(val->conjugate, false);
std::vector<int32_t> perm = {1, 0};
ASSERT_EQ(val->perm, perm);
}

} // namespace mindspore

+ 3
- 4
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc View File

@@ -35,10 +35,9 @@ TEST_F(TestTfliteParserUnique, OpType) {
} }


TEST_F(TestTfliteParserUnique, AttrValue) { TEST_F(TestTfliteParserUnique, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32
auto val = meta_graph->nodes.front()->primitive->value.AsUnique();
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
ASSERT_EQ(val->outType, 34);
} }
} // namespace mindspore } // namespace mindspore

+ 3
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc View File

@@ -35,11 +35,9 @@ TEST_F(TestTfliteParserUnstack, OpType) {
} }


TEST_F(TestTfliteParserUnstack, AttrValue) { TEST_F(TestTfliteParserUnstack, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1);
auto val = meta_graph->nodes.front()->primitive->value.AsUnstack();
ASSERT_EQ(val->num, 5);
ASSERT_EQ(val->axis, 1);
} }
} // namespace mindspore } // namespace mindspore

+ 2
- 2
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc View File

@@ -353,7 +353,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->format == schema::Format_KCHW) { } else if (weightTensor->format == schema::Format_KCHW) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) {
} else if (weightTensor->format == schema::Format_CHWK) { // from tflite
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else { } else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
@@ -369,7 +369,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
} else if (weightTensor->format == schema::Format_CHWK) { // from tflite
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else { } else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;


+ 8
- 3
mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc View File

@@ -21,11 +21,16 @@ namespace lite {
STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
if (op == nullptr) { if (op == nullptr) {
// MS_LOGE("null pointer dereferencing.");
// MS_LOG(ERROR) << "null pointer dereferencing.";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT());
attr->format = schema::Format_NCHW;
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
const caffe::FlattenParameter flattenParam = proto.flatten_param();

attr->axis = (int32_t)flattenParam.axis();
attr->useAxis = true;
attr->hasBias = false;
attr->activationType = schema::ActivationType_NO_ACTIVATION;


op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Flatten; op->primitive->value.type = schema::PrimitiveType_Flatten;


+ 68
- 24
mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc View File

@@ -14,18 +14,21 @@
* limitations under the License. * limitations under the License.
*/ */


#include "tools/converter/parser/tflite/tflite_activation_parser.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <string> #include <string>
#include "tools/converter/parser/tflite/tflite_activation_parser.h"
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -35,13 +38,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
MS_LOG(ERROR) << "op->primitive is null"; MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }

std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());


std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name, &node_name_str, "-"); Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();

if (std::strcmp(node_name, "Relu") == 0) { if (std::strcmp(node_name, "Relu") == 0) {
MS_LOG(DEBUG) << "parse TfliteReluParser"; MS_LOG(DEBUG) << "parse TfliteReluParser";
attr->type = schema::ActivationType_RELU; attr->type = schema::ActivationType_RELU;
@@ -54,29 +55,31 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
} else if (std::strcmp(node_name, "Logistic") == 0) { } else if (std::strcmp(node_name, "Logistic") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser"; MS_LOG(DEBUG) << "parse TfliteLogisticParser";
attr->type = schema::ActivationType_SIGMOID; attr->type = schema::ActivationType_SIGMOID;
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions();
if (option == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_LEAKY_RELU;
attr->alpha = option->alpha;
} else {
MS_LOG(ERROR) << "wrong activation type";
return RET_ERROR;
} else if (std::strcmp(node_name, "HardSwish") == 0) {
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
attr->type = schema::ActivationType_SIGMOID;
} }


attr->alpha = 0.2f;
op->primitive->value.type = schema::PrimitiveType_Activation; op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }


STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) {
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TflitePreluParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -86,23 +89,64 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
MS_LOG(ERROR) << "op->primitive is null"; MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }

MS_LOG(DEBUG) << "paser TflitePreluParser";
std::unique_ptr<schema::PreluT> attr(new schema::PreluT()); std::unique_ptr<schema::PreluT> attr(new schema::PreluT());


if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
MS_LOG(ERROR) << "get pRelu -> slope failed"; MS_LOG(ERROR) << "get pRelu -> slope failed";
return RET_ERROR; return RET_ERROR;
} }

op->primitive->value.type = schema::PrimitiveType_Prelu; op->primitive->value.type = schema::PrimitiveType_Prelu;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}

STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";

if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}

std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT());

const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->negativeSlope = tflite_attr->alpha;

op->primitive->value.type = schema::PrimitiveType_LeakyReLU;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }


TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser()); TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());


+ 32
- 14
mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h View File

@@ -14,13 +14,14 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef PREDICT_TFLITE_RELU_PARSER_H
#define PREDICT_TFLITE_RELU_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H


#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@@ -29,11 +30,13 @@ class TfliteActivationParser : public TfliteNodeParser {
public: public:
TfliteActivationParser() : TfliteNodeParser("node_name") {} TfliteActivationParser() : TfliteNodeParser("node_name") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };


class TfliteReluParser : public TfliteActivationParser { class TfliteReluParser : public TfliteActivationParser {
@@ -56,9 +59,9 @@ class TfliteLogisticParser : public TfliteActivationParser {
TfliteLogisticParser() : TfliteActivationParser() {} TfliteLogisticParser() : TfliteActivationParser() {}
}; };


class TfliteLeakyReluParser : public TfliteActivationParser {
class TfliteHardSwishParser : public TfliteActivationParser {
public: public:
TfliteLeakyReluParser() : TfliteActivationParser() {}
TfliteHardSwishParser() : TfliteActivationParser() {}
}; };


class TflitePreluParser : public TfliteNodeParser { class TflitePreluParser : public TfliteNodeParser {
@@ -68,12 +71,27 @@ class TflitePreluParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};

class TfliteLeakyReluParser : public TfliteNodeParser {
public:
TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };


} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // PREDICT_TFLITE_RELU_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H



+ 20
- 8
mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc View File

@@ -18,14 +18,20 @@
#include "tools/converter/parser/tflite/tflite_addn_parser.h" #include "tools/converter/parser/tflite/tflite_addn_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteAddNParser";

// set attr
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -36,13 +42,19 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteAddNParser";
std::unique_ptr<schema::AddNT> attr(new schema::AddNT()); std::unique_ptr<schema::AddNT> attr(new schema::AddNT());


attr->N = tfliteTensors.size() - 1;

attr->N = tflite_tensors.size() - 1;
op->primitive->value.type = schema::PrimitiveType_AddN; op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

// set input
for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 8
- 6
mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef LITE_TFLITE_ADDN_PARSER_H
#define LITE_TFLITE_ADDN_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -31,11 +32,12 @@ class TfliteAddNParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // LITE_TFLITE_ADDN_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H

+ 18
- 10
mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc View File

@@ -17,16 +17,19 @@
#include "tools/converter/parser/tflite/tflite_argmax_parser.h" #include "tools/converter/parser/tflite/tflite_argmax_parser.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -37,7 +40,6 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT()); std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());


attr->outMaxValue = false; attr->outMaxValue = false;
@@ -45,9 +47,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->keepDims = false; attr->keepDims = false;
attr->axisType = 1; attr->axisType = 1;


auto axis_idx = tfliteOp->inputs[1];
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer];
// get axis attr
auto axis_idx = tflite_op->inputs[1];
std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
if (buf_data == nullptr) { if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null"; MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -61,6 +64,11 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit


op->primitive->value.type = schema::PrimitiveType_ArgMax; op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 11
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef PREDICT_TFLITE_ARGMAX_PARSER_H
#define PREDICT_TFLITE_ARGMAX_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H


#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -28,14 +29,15 @@ class TfliteArgmaxParser : public TfliteNodeParser {
public: public:
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // PREDICT_TFLITE_ARGMAX_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H

+ 20
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc View File

@@ -17,14 +17,19 @@
#include "tools/converter/parser/tflite/tflite_argmin_parser.h" #include "tools/converter/parser/tflite/tflite_argmin_parser.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteArgminParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -35,7 +40,6 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteArgminParser";
std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT()); std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT());


attr->outMaxValue = false; attr->outMaxValue = false;
@@ -43,9 +47,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->keepDims = false; attr->keepDims = false;
attr->axisType = 1; attr->axisType = 1;


auto axis_idx = tfliteOp->inputs[1];
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer];
// get axis attr
auto axis_idx = tflite_op->inputs[1];
std::for_each(tflite_tensors[axis_idx]->shape.begin(),
tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
if (buf_data == nullptr) { if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null"; MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -59,6 +65,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit


op->primitive->value.type = schema::PrimitiveType_ArgMin; op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 11
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef PREDICT_TFLITE_ARGMIN_PARSER_H
#define PREDICT_TFLITE_ARGMIN_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H


#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -28,14 +29,15 @@ class TfliteArgminParser : public TfliteNodeParser {
public: public:
TfliteArgminParser() : TfliteNodeParser("Argmin") {} TfliteArgminParser() : TfliteNodeParser("Argmin") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // PREDICT_TFLITE_ARGMIN_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H

+ 84
- 136
mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc View File

@@ -18,14 +18,17 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -37,124 +40,72 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
} }


std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();

if (std::strcmp(node_name, "Add") == 0
|| std::strcmp(node_name, "Sub") == 0
|| std::strcmp(node_name, "Mul") == 0
|| std::strcmp(node_name, "Div") == 0) {
auto x_index = tfliteOp->inputs[0];
const auto &x_tensor = tfliteTensors[x_index];
if (x_tensor == nullptr) {
MS_LOG(ERROR) << "the first input is null";
if (std::strcmp(node_name, "Add") == 0) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT());
const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
if (x_data == nullptr) {
MS_LOG(ERROR) << "the data of the first input is null";
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Sub") == 0) {
MS_LOG(DEBUG) << "parse TfliteSubParser";
std::unique_ptr<schema::SubT> attr(new schema::SubT());
const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
if (!x_data->data.empty()) {
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse the first tensor failed";
return RET_ERROR;
}
}

auto y_index = tfliteOp->inputs[1];
const auto &y_tensor = tfliteTensors[y_index];
if (y_tensor == nullptr) {
MS_LOG(ERROR) << "the second input is null";
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Mul") == 0) {
MS_LOG(DEBUG) << "parse TfliteMulParser";
std::unique_ptr<schema::MulT> attr(new schema::MulT());
const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
if (y_data == nullptr) {
MS_LOG(ERROR) << "the data of the second input is null";
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Div") == 0) {
MS_LOG(DEBUG) << "parse TfliteDivParser";
std::unique_ptr<schema::DivT> attr(new schema::DivT());
const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
if (!y_data->data.empty()) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}
}

if (std::strcmp(node_name, "Add") == 0) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sub") == 0) {
MS_LOG(DEBUG) << "parse TfliteSubParser";
std::unique_ptr<schema::SubT> attr(new schema::SubT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Mul") == 0) {
MS_LOG(DEBUG) << "parse TfliteMulParser";
std::unique_ptr<schema::MulT> attr(new schema::MulT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Div") == 0) {
MS_LOG(DEBUG) << "parse TfliteDivParser";
std::unique_ptr<schema::DivT> attr(new schema::DivT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
return RET_OK;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "FloorDiv") == 0) { } else if (std::strcmp(node_name, "FloorDiv") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT()); std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT());
op->primitive->value.type = schema::PrimitiveType_FloorDiv; op->primitive->value.type = schema::PrimitiveType_FloorDiv;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "FloorMod") == 0) { } else if (std::strcmp(node_name, "FloorMod") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorModParser"; MS_LOG(DEBUG) << "parse TfliteFloorModParser";
std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT()); std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT());
op->primitive->value.type = schema::PrimitiveType_FloorMod; op->primitive->value.type = schema::PrimitiveType_FloorMod;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "RealDiv") == 0) { } else if (std::strcmp(node_name, "RealDiv") == 0) {
MS_LOG(DEBUG) << "parse TfliteRealDivParser"; MS_LOG(DEBUG) << "parse TfliteRealDivParser";
std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT()); std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT());
op->primitive->value.type = schema::PrimitiveType_RealDiv;
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "SquaredDifference") == 0) { } else if (std::strcmp(node_name, "SquaredDifference") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT()); std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT());
op->primitive->value.type = schema::PrimitiveType_SquaredDifference; op->primitive->value.type = schema::PrimitiveType_SquaredDifference;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Pow") == 0) { } else if (std::strcmp(node_name, "Pow") == 0) {
MS_LOG(DEBUG) << "parse TflitePowParser"; MS_LOG(DEBUG) << "parse TflitePowParser";
std::unique_ptr<schema::PowerT> attr(new schema::PowerT()); std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
@@ -163,31 +114,35 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
attr->shift = 0.0f; attr->shift = 0.0f;
op->primitive->value.type = schema::PrimitiveType_Power; op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Maximum") == 0) { } else if (std::strcmp(node_name, "Maximum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMaximumParser"; MS_LOG(DEBUG) << "parse TfliteMaximumParser";
std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT()); std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT());
op->primitive->value.type = schema::PrimitiveType_Maximum; op->primitive->value.type = schema::PrimitiveType_Maximum;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Minimum") == 0) { } else if (std::strcmp(node_name, "Minimum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMinimumParser"; MS_LOG(DEBUG) << "parse TfliteMinimumParser";
std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT()); std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT());
op->primitive->value.type = schema::PrimitiveType_Minimum; op->primitive->value.type = schema::PrimitiveType_Minimum;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
} }

// set input
for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }


STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -199,85 +154,79 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
} }


std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Abs") == 0) { if (std::strcmp(node_name, "Abs") == 0) {
MS_LOG(DEBUG) << "parse TfliteAbsParser"; MS_LOG(DEBUG) << "parse TfliteAbsParser";
std::unique_ptr<schema::AbsT> attr(new schema::AbsT()); std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
op->primitive->value.type = schema::PrimitiveType_Abs; op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Exp") == 0) { } else if (std::strcmp(node_name, "Exp") == 0) {
MS_LOG(DEBUG) << "parse TfliteExpParser"; MS_LOG(DEBUG) << "parse TfliteExpParser";
std::unique_ptr<schema::ExpT> attr(new schema::ExpT()); std::unique_ptr<schema::ExpT> attr(new schema::ExpT());
op->primitive->value.type = schema::PrimitiveType_Exp; op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sqrt") == 0) { } else if (std::strcmp(node_name, "Sqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteSqrtParser"; MS_LOG(DEBUG) << "parse TfliteSqrtParser";
std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT()); std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT());
op->primitive->value.type = schema::PrimitiveType_Sqrt; op->primitive->value.type = schema::PrimitiveType_Sqrt;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Rsqrt") == 0) { } else if (std::strcmp(node_name, "Rsqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT()); std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT());
op->primitive->value.type = schema::PrimitiveType_Rsqrt; op->primitive->value.type = schema::PrimitiveType_Rsqrt;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Square") == 0) { } else if (std::strcmp(node_name, "Square") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquareParser"; MS_LOG(DEBUG) << "parse TfliteSquareParser";
std::unique_ptr<schema::SquareT> attr(new schema::SquareT()); std::unique_ptr<schema::SquareT> attr(new schema::SquareT());
op->primitive->value.type = schema::PrimitiveType_Square; op->primitive->value.type = schema::PrimitiveType_Square;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sin") == 0) { } else if (std::strcmp(node_name, "Sin") == 0) {
MS_LOG(DEBUG) << "parse TfliteSinParser"; MS_LOG(DEBUG) << "parse TfliteSinParser";
std::unique_ptr<schema::SinT> attr(new schema::SinT()); std::unique_ptr<schema::SinT> attr(new schema::SinT());
op->primitive->value.type = schema::PrimitiveType_Sin; op->primitive->value.type = schema::PrimitiveType_Sin;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Cos") == 0) { } else if (std::strcmp(node_name, "Cos") == 0) {
MS_LOG(DEBUG) << "parse TfliteCosParser"; MS_LOG(DEBUG) << "parse TfliteCosParser";
std::unique_ptr<schema::CosT> attr(new schema::CosT()); std::unique_ptr<schema::CosT> attr(new schema::CosT());
op->primitive->value.type = schema::PrimitiveType_Cos; op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Log") == 0) { } else if (std::strcmp(node_name, "Log") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogParser"; MS_LOG(DEBUG) << "parse TfliteLogParser";
std::unique_ptr<schema::LogT> attr(new schema::LogT()); std::unique_ptr<schema::LogT> attr(new schema::LogT());
op->primitive->value.type = schema::PrimitiveType_Log; op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Round") == 0) { } else if (std::strcmp(node_name, "Round") == 0) {
MS_LOG(DEBUG) << "parse TfliteRoundParser"; MS_LOG(DEBUG) << "parse TfliteRoundParser";
std::unique_ptr<schema::RoundT> attr(new schema::RoundT()); std::unique_ptr<schema::RoundT> attr(new schema::RoundT());
op->primitive->value.type = schema::PrimitiveType_Round; op->primitive->value.type = schema::PrimitiveType_Round;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Ceil") == 0) { } else if (std::strcmp(node_name, "Ceil") == 0) {
MS_LOG(DEBUG) << "parse TfliteCeilParser"; MS_LOG(DEBUG) << "parse TfliteCeilParser";
std::unique_ptr<schema::CeilT> attr(new schema::CeilT()); std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
op->primitive->value.type = schema::PrimitiveType_Ceil; op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "flOOR") == 0) { } else if (std::strcmp(node_name, "flOOR") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorParser"; MS_LOG(DEBUG) << "parse TfliteFloorParser";
std::unique_ptr<schema::FloorT> attr(new schema::FloorT()); std::unique_ptr<schema::FloorT> attr(new schema::FloorT());
op->primitive->value.type = schema::PrimitiveType_Floor; op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
} }

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
} }


STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -289,48 +238,47 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
} }


std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Equal") == 0) { if (std::strcmp(node_name, "Equal") == 0) {
MS_LOG(DEBUG) << "parse TfliteEqualParser"; MS_LOG(DEBUG) << "parse TfliteEqualParser";
std::unique_ptr<schema::EqualT> attr(new schema::EqualT()); std::unique_ptr<schema::EqualT> attr(new schema::EqualT());
op->primitive->value.type = schema::PrimitiveType_Equal; op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "NotEqual") == 0) { } else if (std::strcmp(node_name, "NotEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT()); std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT());
op->primitive->value.type = schema::PrimitiveType_NotEqual; op->primitive->value.type = schema::PrimitiveType_NotEqual;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Greater") == 0) { } else if (std::strcmp(node_name, "Greater") == 0) {
MS_LOG(DEBUG) << "parse TfliteGreaterParser"; MS_LOG(DEBUG) << "parse TfliteGreaterParser";
std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT()); std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT());
op->primitive->value.type = schema::PrimitiveType_Greater; op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "GreaterEqual") == 0) { } else if (std::strcmp(node_name, "GreaterEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT()); std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT());
op->primitive->value.type = schema::PrimitiveType_GreaterEqual; op->primitive->value.type = schema::PrimitiveType_GreaterEqual;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Less") == 0) { } else if (std::strcmp(node_name, "Less") == 0) {
MS_LOG(DEBUG) << "parse TfliteLessParser"; MS_LOG(DEBUG) << "parse TfliteLessParser";
std::unique_ptr<schema::LessT> attr(new schema::LessT()); std::unique_ptr<schema::LessT> attr(new schema::LessT());
op->primitive->value.type = schema::PrimitiveType_Less; op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "LessEqual") == 0) { } else if (std::strcmp(node_name, "LessEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT()); std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT());
op->primitive->value.type = schema::PrimitiveType_LessEqual; op->primitive->value.type = schema::PrimitiveType_LessEqual;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
} }

for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
} }


TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser());


+ 25
- 18
mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef PREDICT_TFLITE_MATH_PARSER_H
#define PREDICT_TFLITE_MATH_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -29,11 +30,13 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
public: public:
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };


class TfliteAddParser : public TfliteDoubleInputOpParser { class TfliteAddParser : public TfliteDoubleInputOpParser {
@@ -96,11 +99,13 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
public: public:
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };


class TfliteAbsParser : public TfliteSingleInputOpParser { class TfliteAbsParser : public TfliteSingleInputOpParser {
@@ -163,11 +168,13 @@ class TfliteCompareOpParser : public TfliteNodeParser {
public: public:
TfliteCompareOpParser() : TfliteNodeParser("node_name") {} TfliteCompareOpParser() : TfliteNodeParser("node_name") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };


class TfliteEqualParser : public TfliteCompareOpParser { class TfliteEqualParser : public TfliteCompareOpParser {
@@ -203,5 +210,5 @@ class TfliteLessEqualParser : public TfliteCompareOpParser {
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // PREDICT_TFLITE_MATH_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H



+ 16
- 11
mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc View File

@@ -19,14 +19,17 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -38,30 +41,32 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
} }


std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "BatchToSpace") == 0) { if (std::strcmp(node_name, "BatchToSpace") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
} else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser";
// in tflite
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
} }


std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT()); std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());


if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) {
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) {
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
return RET_ERROR; return RET_ERROR;
} }
if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) {
if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) {
MS_LOG(ERROR) << "get batchToSpace -> crops failed"; MS_LOG(ERROR) << "get batchToSpace -> crops failed";
return RET_ERROR; return RET_ERROR;
} }


op->primitive->value.type = schema::PrimitiveType_BatchToSpace; op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 8
- 5
mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
#define LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -31,8 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };


class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
@@ -43,4 +46,4 @@ class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H

+ 16
- 7
mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc View File

@@ -18,14 +18,19 @@
#include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h" #include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -36,16 +41,20 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT()); std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT());


if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dst_shape)) {
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) {
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
return RET_ERROR; return RET_ERROR;
} }


op->primitive->value.type = schema::PrimitiveType_BroadcastTo; op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 8
- 6
mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef LITE_TFLITE_BROADCAST_TO_PARSER_H
#define LITE_TFLITE_BROADCAST_TO_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -31,11 +32,12 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // LITE_TFLITE_BROADCAST_TO_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H

+ 19
- 12
mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc View File

@@ -14,18 +14,22 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */

#include "tools/converter/parser/tflite/tflite_cast_parser.h" #include "tools/converter/parser/tflite/tflite_cast_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteCastParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -36,25 +40,28 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteCastParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT()); std::unique_ptr<schema::CastT> attr(new schema::CastT());


const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]];
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) { if (in_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null"; MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->srcT = dtype_map[in_tensor->type];

const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) { if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null"; MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->dstT = dtype_map[out_tensor->type];
attr->dstT = GetTfliteDataType(out_tensor->type);


op->primitive->value.type = schema::PrimitiveType_Cast; op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 8
- 6
mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef LITE_TFLITE_CAST_PARSER_
#define LITE_TFLITE_CAST_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -31,11 +32,12 @@ class TfliteCastParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // LITE_TFLITE_CAST_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H

+ 20
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc View File

@@ -17,14 +17,20 @@
#include "tools/converter/parser/tflite/tflite_concat_parser.h" #include "tools/converter/parser/tflite/tflite_concat_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteConcatParser";

// set attr
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -35,20 +41,25 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteConcatParser";
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT()); std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());


const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions();
const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions();
if (tfliteAttr == nullptr) { if (tfliteAttr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->axis = tfliteAttr->axis; attr->axis = tfliteAttr->axis;

attr->n = tfliteOp->inputs.size();
attr->n = tflite_op->inputs.size();


op->primitive->value.type = schema::PrimitiveType_Concat; op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 11
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef PREDICT_TFLITE_CONCAT_PARSER_H
#define PREDICT_TFLITE_CONCAT_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H


#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -28,15 +29,16 @@ class TfliteConcatParser : public TfliteNodeParser {
public: public:
TfliteConcatParser() : TfliteNodeParser("Concat") {} TfliteConcatParser() : TfliteNodeParser("Concat") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // PREDICT_TFLITE_CONCAT_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H



+ 48
- 42
mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc View File

@@ -17,14 +17,19 @@
#include "tools/converter/parser/tflite/tflite_conv_parser.h" #include "tools/converter/parser/tflite/tflite_conv_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteConvParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -35,60 +40,61 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteConvParser";
std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT()); std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions();
if (tfliteAttr == nullptr) {
const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->group = 1; attr->group = 1;
attr->strideW = tfliteAttr->stride_w;
attr->strideH = tfliteAttr->stride_h;
attr->dilateH = tfliteAttr->dilation_h_factor;
attr->dilateW = tfliteAttr->dilation_w_factor;
attr->padMode = GetPadMode(tfliteAttr->padding);
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->dilateH = tflite_attr->dilation_h_factor;
attr->dilateW = tflite_attr->dilation_w_factor;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
attr->hasBias = true;


// get the conv op weight tensor // get the conv op weight tensor
auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) { if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null";
MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
auto weight_shape = weight_tensor->shape; auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[KHWC_C];
attr->channelOut = weight_shape[KHWC_K];
attr->kernelW = weight_shape[KHWC_W];
attr->kernelH = weight_shape[KHWC_H];

// get the conv op bias tensor
if (tfliteOp->inputs.size() == 3) {
attr->hasBias = true;
auto bias_index = tfliteOp->inputs[2];
const auto &bias_tensor = tfliteTensors[bias_index];
if (bias_tensor == nullptr) {
MS_LOG(ERROR) << "bias_tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}
attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[0];
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];


// calculate pad params // calculate pad params
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_tensors[data_index];
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
attr->strideW, attr->kernelH, attr->kernelW, &params) != RET_OK) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}


op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 11
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef PREDICT_TFLITE_CONV_PARSER_H
#define PREDICT_TFLITE_CONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -28,15 +29,16 @@ class TfliteConvParser : public TfliteNodeParser {
public: public:
TfliteConvParser() : TfliteNodeParser("Conv2D") {} TfliteConvParser() : TfliteNodeParser("Conv2D") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // PREDICT_TFLITE_CONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H



+ 1
- 0
mindspore/lite/tools/converter/parser/tflite/tflite_converter.h View File

@@ -19,6 +19,7 @@


#include <string> #include <string>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/converter.h" #include "tools/converter/converter.h"
#include "tools/converter/parser/tflite/tflite_model_parser.h" #include "tools/converter/parser/tflite/tflite_model_parser.h"
#include "tools/converter/graphdef_transform.h" #include "tools/converter/graphdef_transform.h"


+ 44
- 18
mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc View File

@@ -17,14 +17,19 @@
#include "tools/converter/parser/tflite/tflite_deconv_parser.h" #include "tools/converter/parser/tflite/tflite_deconv_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -35,11 +40,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT()); std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT());
const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions();
const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
if (tflite_attr == nullptr) { if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }


@@ -50,26 +54,48 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->dilateW = 1; attr->dilateW = 1;
attr->padMode = GetPadMode(tflite_attr->padding); attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;
attr->activationType = schema::ActivationType_NO_ACTIVATION;
attr->hasBias = true;


// get the conv op weight tensor // get the conv op weight tensor
auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) { if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null";
MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[0];
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];

// calculate pad params
auto data_index = tflite_op->inputs[2];
const auto &data_tensor = tflite_tensors[data_index];
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
attr->strideW, attr->kernelH, attr->kernelW, &params) != RET_OK) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR; return RET_ERROR;
} else {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
} }
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[CHWK_K];
attr->channelOut = weight_shape[CHWK_C];
attr->kernelW = weight_shape[CHWK_W];
attr->kernelH = weight_shape[CHWK_H];


op->primitive->value.type = schema::PrimitiveType_DeConv2D; op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 8
- 6
mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef PREDICT_TFLITE_DECONV_PARSER_H
#define PREDICT_TFLITE_DECONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -31,11 +32,12 @@ class TfliteDeConvParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // PREDICT_TFLITE_DECONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H

+ 16
- 8
mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc View File

@@ -18,14 +18,19 @@
#include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h" #include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -36,20 +41,23 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT()); std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT());


const auto &tflite_attr = tfliteOp->builtin_options.AsDepthToSpaceOptions();
const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions();
if (tflite_attr == nullptr) { if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->blockSize = tflite_attr->block_size; attr->blockSize = tflite_attr->block_size;

attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;


op->primitive->value.type = schema::PrimitiveType_DepthToSpace; op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 8
- 6
mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
#define LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -31,11 +32,12 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H

+ 43
- 87
mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc View File

@@ -17,65 +17,22 @@
#include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h" #include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/common/node_util.h" #include "tools/common/node_util.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op,
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
const std::unique_ptr<tflite::TensorT> &weightTensor,
TensorCache *tensor_cache) {
std::unique_ptr<schema::Conv2DT> convAttr(new schema::Conv2DT);


convAttr->format = attr->format;
convAttr->channelIn = attr->channelIn;
convAttr->channelOut = attr->channelIn * attr->channelMultiplier;
convAttr->kernelH = attr->kernelH;
convAttr->kernelW = attr->kernelW;
convAttr->strideH = attr->strideH;
convAttr->strideW = attr->strideW;
convAttr->padMode = attr->padMode;
convAttr->padUp = attr->padUp;
convAttr->padDown = attr->padDown;
convAttr->padLeft = attr->padLeft;
convAttr->padRight = attr->padRight;
convAttr->dilateH = attr->dilateH;
convAttr->dilateW = attr->dilateW;
convAttr->hasBias = attr->hasBias;
convAttr->activationType = attr->activationType;

auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name);
if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) {
auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex];
if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) {
// convert weight format KHWC -> CHWK
auto status = TransFilterFormat<uint8_t>(liteWeightTensor, kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
return RET_ERROR;
}
}

if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) {
// convert weight format KHWC -> CHWK
auto status = TransFilterFormat<float>(liteWeightTensor, kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
return RET_ERROR;
}
}
}

op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = convAttr.release();
return RET_OK;
}


STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -86,7 +43,6 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT()); std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT());
const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions();
if (tflite_attr == nullptr) { if (tflite_attr == nullptr) {
@@ -100,15 +56,20 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
attr->padMode = GetPadMode(tflite_attr->padding); attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
// get the conv op weight tensor
auto input_index = tflite_op->inputs[0];
const auto &input_tenosr = tflite_tensors[input_index];
if (input_tenosr == nullptr) {
MS_LOG(ERROR) << "the first input is null";
attr->hasBias = true;
attr->channelMultiplier = tflite_attr->depth_multiplier;

// get the data tensor
auto data_index = tflite_op->inputs[1];
const auto &data_tensor = tflite_tensors[data_index];
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "the data tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto input_shape = input_tenosr->shape;
auto data_shape = data_tensor->shape;
attr->channelIn = data_shape[3];


// get the weight tensor
auto weight_index = tflite_op->inputs[1]; auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_tensors[weight_index]; const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) { if (weight_tensor == nullptr) {
@@ -116,38 +77,33 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto weight_shape = weight_tensor->shape; auto weight_shape = weight_tensor->shape;
attr->channelIn = input_shape[KHWC_C];
attr->channelMultiplier = tflite_attr->depth_multiplier;
attr->kernelH = weight_shape[KHWC_H];
attr->kernelW = weight_shape[KHWC_W];

std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};

if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];

// calculate pad params
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW,
attr->kernelH, attr->kernelW, &params) != RET_OK) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR; return RET_ERROR;
}

if (tflite_op->inputs.size() == 3) {
attr->hasBias = true;
auto bias_index = tflite_op->inputs[2];
const auto &bias_tensor = tflite_tensors[bias_index];
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}

if (attr->channelMultiplier > 1) {
if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) {
MS_LOG(ERROR) << "Parse Group DepthwiseConv failed";
return RET_ERROR;
}
} else { } else {
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = attr.release();
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
} }

op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




+ 11
- 14
mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H
#define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H


#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -28,20 +29,16 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
public: public:
TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override;

private:
STATUS ParseGroupDepthwiseConv(schema::CNodeT *op,
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
const std::unique_ptr<tflite::TensorT> &weightTensor,
TensorCache *tensor_cache);
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // PREDICT_TFLITE_CONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H



+ 22
- 19
mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc View File

@@ -16,15 +16,20 @@
#include "tools/converter/parser/tflite/tflite_dequantize_parser.h" #include "tools/converter/parser/tflite/tflite_dequantize_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/common/node_util.h" #include "tools/common/node_util.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -35,32 +40,30 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT); std::unique_ptr<schema::CastT> attr(new schema::CastT);


// get the dequantize input tensor // get the dequantize input tensor
const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]];
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) { if (in_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null";
MS_LOG(ERROR) << "input tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->srcT = dtype_map[in_tensor->type];

const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) { if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
MS_LOG(ERROR) << "output tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->dstT = dtype_map[out_tensor->type];
std::vector<tflite::TensorT *> weight_tensors{in_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
attr->dstT = GetTfliteDataType(out_tensor->type);


op->primitive->value.type = schema::PrimitiveType_Fp16Cast; op->primitive->value.type = schema::PrimitiveType_Fp16Cast;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return 0;

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
} }


TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());


+ 11
- 8
mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h View File

@@ -13,11 +13,12 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H
#define LITE_TFLITE_DEQUANTIZE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H


#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -27,13 +28,15 @@ class TfliteDequantizeParser : public TfliteNodeParser {
public: public:
TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // LITE_TFLITE_DEQUANTIZE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H

+ 9
- 8
mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc View File

@@ -17,16 +17,17 @@
#include "tools/converter/parser/tflite/tflite_expand_dims_parser.h" #include "tools/converter/parser/tflite/tflite_expand_dims_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -40,7 +41,7 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; MS_LOG(DEBUG) << "parse TfliteExpandDimsParser";
std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT()); std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT());


const auto &tflite_attr = tfliteOp->builtin_options.AsExpandDimsOptions();
const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions();
if (tflite_attr == nullptr) { if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;


+ 11
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h View File

@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
#define PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"


@@ -28,15 +29,16 @@ class TfliteExpandDimsParser : public TfliteNodeParser {
public: public:
TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {}


STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H



+ 0
- 75
mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc View File

@@ -1,75 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_fakequant_parser.h"
#include <vector>
#include <memory>

namespace mindspore {
namespace lite {
STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser";
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());

auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}

if (tfliteOp->inputs.size() == 3) {
attr->hasBias = true;
auto bias_index = tfliteOp->inputs[2];
const auto &bias_tensor = tfliteTensors[bias_index];
if (bias_tensor == nullptr) {
MS_LOG(ERROR) << "bias_tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}
attr->axis = 1;

op->primitive->value.type = schema::PrimitiveType_FullConnection;
op->primitive->value.value = attr.release();
return RET_OK;
}

TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser());
} // namespace lite
} // namespace mindspore

+ 0
- 39
mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h View File

@@ -1,39 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_FAKEQUANT_PARSER_H
#define LITE_TFLITE_FAKEQUANT_PARSER_H

#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

namespace mindspore {
namespace lite {
class TfliteFakeQuantParser : public TfliteNodeParser {
public:
TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_TFLITE_FAKEQUANT_PARSER_H

+ 19
- 12
mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc View File

@@ -14,19 +14,22 @@
* limitations under the License. * limitations under the License.
*/ */


#include "tools/converter/parser/tflite/tflite_fill_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "tools/converter/parser/tflite/tflite_fill_parser.h"
#include <map>


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteFillParser";

if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@@ -37,18 +40,22 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR; return RET_NULL_PTR;
} }


MS_LOG(DEBUG) << "parse TfliteFillParser";
std::unique_ptr<schema::FillT> attr(new schema::FillT()); std::unique_ptr<schema::FillT> attr(new schema::FillT());


if (tfliteOp->inputs.size() > 1) {
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dims)) {
MS_LOG(ERROR) << "get Fill -> dims failed";
if (tflite_op->inputs.size() > 1) {
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) {
MS_LOG(ERROR) << "get fill -> dims failed";
return RET_ERROR; return RET_ERROR;
} }
} }


op->primitive->value.type = schema::PrimitiveType_Fill; op->primitive->value.type = schema::PrimitiveType_Fill;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }




Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save