| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||
| #include <iostream> | |||
| #include "common/common_test.h" | |||
| namespace mindspore { | |||
| class TestTfliteParserL2Norm : public TestTfliteParser { | |||
| public: | |||
| TestTfliteParserL2Norm() = default; | |||
| void SetUp() override { meta_graph = LoadAndConvert("./l2norm.tflite", ""); } | |||
| }; | |||
| TEST_F(TestTfliteParserL2Norm, OpType) { | |||
| ASSERT_NE(meta_graph, nullptr); | |||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_L2Norm) << "wrong Op Type"; | |||
| } | |||
| TEST_F(TestTfliteParserL2Norm, AttrValue) { | |||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsL2Norm(), nullptr); | |||
| auto val = meta_graph->nodes.front()->primitive->value.AsL2Norm(); | |||
| ASSERT_EQ(val->epsilon, 0.0); | |||
| std::vector<int32_t> axis = {0, 1, 2, 3}; | |||
| ASSERT_EQ(val->axis, axis); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -92,8 +92,6 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||
| tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||
| tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||
| return RET_OK; | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * distributed under the License is distributed on an AS | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tflite/tflite_l2norm_parser.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteL2NormParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||
| schema::CNodeT *op, | |||
| std::vector<int32_t> *tensors_id, | |||
| std::vector<schema::Format> *tensors_format, | |||
| std::map<int, int> *tensors_id_map) { | |||
| MS_LOG(DEBUG) << "parse TfliteL2NormParser"; | |||
| // set attr | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::L2NormT> attr(new schema::L2NormT()); | |||
| auto data_index = tflite_op->inputs[0]; | |||
| const auto &data_tensor = tflite_tensors[data_index]; | |||
| if (data_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "the input tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto ndim = data_tensor->shape.size(); | |||
| std::vector<int32_t> axis; | |||
| axis.reserve(ndim); | |||
| for (int i = 0; i < ndim; i++) { | |||
| axis.emplace_back(i); | |||
| } | |||
| attr->axis = axis; | |||
| attr->epsilon = 0.0f; | |||
| op->primitive->value.type = schema::PrimitiveType_L2Norm; | |||
| op->primitive->value.value = attr.release(); | |||
| // set input | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TfliteL2NormParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} | |||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||
| schema::CNodeT *op, | |||
| std::vector<int32_t> *tensors_id, | |||
| std::vector<schema::Format> *tensors_format, | |||
| std::map<int, int> *tensors_id_map) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H | |||
| @@ -96,6 +96,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||
| for (const auto &tflite_op : tflite_subgraph->operators) { | |||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||
| auto op_type = GetMSOpType(tflite_op_type); | |||
| if (op_type == "CUSTOM") { | |||
| auto custom_type = (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code; | |||
| MS_LOG(ERROR) << "CUSTOM op is not supported, the type is " << custom_type; | |||
| return RET_ERROR; | |||
| } | |||
| std::unique_ptr<schema::CNodeT> op(new schema::CNodeT); | |||
| op->name = op_type + "-" + std::to_string(idx++); | |||
| @@ -216,7 +221,7 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> | |||
| return RET_OK; | |||
| } | |||
| STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) { | |||
| STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) { | |||
| for (auto &op : sub_graph->nodes) { | |||
| if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| auto attr = op->primitive->value.AsDepthwiseConv2D(); | |||
| @@ -248,7 +253,6 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) { | |||
| auto weight_id = op->inputIndex[1]; | |||
| auto &weight_tensor = sub_graph->allTensors.at(weight_id); | |||
| if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { | |||
| // convert weight format KHWC -> CHWK | |||
| auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | |||
| @@ -256,13 +260,13 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) { | |||
| } | |||
| } | |||
| if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { | |||
| // convert weight format KHWC -> CHWK | |||
| auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | |||
| MS_LOG(ERROR) << "Trans filter format failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| weight_tensor->format = schema::Format_CHWK; | |||
| } | |||
| } | |||
| } | |||
| @@ -303,8 +307,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, const std::s | |||
| } | |||
| // update for depthwiseConv | |||
| if (UpdateOp(sub_graph.get()) != RET_OK) { | |||
| MS_LOG(ERROR) << "update depthwise conv failed"; | |||
| if (ConvertGroupDepthwiseOp(sub_graph.get()) != RET_OK) { | |||
| MS_LOG(ERROR) << "convert group depthwise conv failed"; | |||
| return nullptr; | |||
| } | |||
| @@ -67,7 +67,7 @@ class TfliteModelParser : public ModelParser { | |||
| STATUS GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| schema::MetaGraphT* sub_graph); | |||
| STATUS UpdateOp(schema::MetaGraphT* sub_graph); | |||
| STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph); | |||
| private: | |||
| std::vector<int32_t> tensorsId; | |||