From a63ee29de8e7f7bf15bba4e522f73692b92125e6 Mon Sep 17 00:00:00 2001 From: lyvette Date: Mon, 17 Aug 2020 18:09:00 +0800 Subject: [PATCH] fix bug when op is custom fix deconv bug --- .../parser/tflite/test_data/l2norm.tflite | Bin 0 -> 516 bytes .../tflite/tflite_l2norm_parser_test.cc | 41 ++++++++++ .../parser/tflite/tflite_deconv_parser.cc | 2 - .../parser/tflite/tflite_l2norm_parser.cc | 75 ++++++++++++++++++ .../parser/tflite/tflite_l2norm_parser.h | 43 ++++++++++ .../parser/tflite/tflite_model_parser.cc | 16 ++-- .../parser/tflite/tflite_model_parser.h | 2 +- 7 files changed, 170 insertions(+), 9 deletions(-) create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/l2norm.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_l2norm_parser_test.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/l2norm.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/l2norm.tflite new file mode 100644 index 0000000000000000000000000000000000000000..0a9a8638763244dc68bcd1269c043c0667bc95eb GIT binary patch literal 516 zcmZvZJ4(c06os!EXNH*&2ep_Y#S~Lm1s~W8iXhB@hJ7>zNOvDdkaHcba?Dd0NEPV^OEMi=SNekd_700X&bh)0~ozqSNS*+{X^Z zYI$`tKQ9aJ)ovXz32^-9KVTx?A((>&knc_K1!^FF3gsT+YnMDpF%z^J@;~IONR#kq zI)Hc=P1E#udp)ypdqbO#;Sdu&`-H1~ +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserL2Norm : public TestTfliteParser { + public: + TestTfliteParserL2Norm() = default; + void SetUp() override { meta_graph = LoadAndConvert("./l2norm.tflite", ""); } +}; + +TEST_F(TestTfliteParserL2Norm, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_L2Norm) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserL2Norm, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsL2Norm(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsL2Norm(); + ASSERT_EQ(val->epsilon, 0.0); + std::vector axis = {0, 1, 2, 3}; + ASSERT_EQ(val->axis, axis); +} + +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index f994f90479..565bceb8a4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -92,8 +92,6 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, - tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc new file mode 100644 index 0000000000..8ef185d907 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc @@ -0,0 +1,75 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include "tools/converter/parser/tflite/tflite_l2norm_parser.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteL2NormParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteL2NormParser"; + + // set attr + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr(new schema::L2NormT()); + auto data_index = tflite_op->inputs[0]; + const auto &data_tensor = tflite_tensors[data_index]; + if (data_tensor == nullptr) { + MS_LOG(ERROR) << "the input tensor is null"; + return RET_NULL_PTR; + } + + auto ndim = data_tensor->shape.size(); + std::vector axis; + axis.reserve(ndim); + for (int i = 0; i < ndim; i++) { + axis.emplace_back(i); + } + attr->axis = axis; + attr->epsilon = 0.0f; + + op->primitive->value.type = schema::PrimitiveType_L2Norm; + op->primitive->value.value = attr.release(); + + // set input + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; +} + +TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h new file mode 100644 index 0000000000..ea9b902be9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H + +#include +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteL2NormParser : public TfliteNodeParser { + public: + TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 0086a3ec0d..b7b11e4577 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -96,6 +96,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit for (const auto &tflite_op : tflite_subgraph->operators) { auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; auto op_type = GetMSOpType(tflite_op_type); + if (op_type == "CUSTOM") { + auto custom_type = (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code; + MS_LOG(ERROR) << "CUSTOM op is not supported, the type is " << custom_type; + return RET_ERROR; + } std::unique_ptr op(new schema::CNodeT); op->name = op_type + "-" + std::to_string(idx++); @@ -216,7 +221,7 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr return RET_OK; } -STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) { +STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) { for (auto &op : sub_graph->nodes) { if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { auto attr = op->primitive->value.AsDepthwiseConv2D(); @@ -248,7 +253,6 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) { auto weight_id = op->inputIndex[1]; auto &weight_tensor = sub_graph->allTensors.at(weight_id); if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { - // convert weight format KHWC -> CHWK auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); if (status != RET_OK) { MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; @@ -256,13 +260,13 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) { } } if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { - // convert weight format KHWC -> CHWK auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); if (status != RET_OK) { - MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + MS_LOG(ERROR) << "Trans filter format failed."; return RET_ERROR; } } + weight_tensor->format = schema::Format_CHWK; } } } @@ -303,8 +307,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, const std::s } // update for depthwiseConv - if (UpdateOp(sub_graph.get()) != RET_OK) { - MS_LOG(ERROR) << "update depthwise conv failed"; + if (ConvertGroupDepthwiseOp(sub_graph.get()) != RET_OK) { + MS_LOG(ERROR) << "convert group depthwise conv failed"; return nullptr; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 4c8eba0728..833ef1cbec 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -67,7 +67,7 @@ class TfliteModelParser : public ModelParser { STATUS GetGraphInfo(const std::unique_ptr &tflite_subgraph, schema::MetaGraphT* sub_graph); - STATUS UpdateOp(schema::MetaGraphT* sub_graph); + STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph); private: std::vector tensorsId;