| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||||
| #include <iostream> | |||||
| #include "common/common_test.h" | |||||
| namespace mindspore { | |||||
| class TestTfliteParserL2Norm : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserL2Norm() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./l2norm.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserL2Norm, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_L2Norm) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserL2Norm, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsL2Norm(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsL2Norm(); | |||||
| ASSERT_EQ(val->epsilon, 0.0); | |||||
| std::vector<int32_t> axis = {0, 1, 2, 3}; | |||||
| ASSERT_EQ(val->axis, axis); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -92,8 +92,6 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | ||||
| tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); | tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | ||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_l2norm_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteL2NormParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteL2NormParser"; | |||||
| // set attr | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::L2NormT> attr(new schema::L2NormT()); | |||||
| auto data_index = tflite_op->inputs[0]; | |||||
| const auto &data_tensor = tflite_tensors[data_index]; | |||||
| if (data_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the input tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto ndim = data_tensor->shape.size(); | |||||
| std::vector<int32_t> axis; | |||||
| axis.reserve(ndim); | |||||
| for (int i = 0; i < ndim; i++) { | |||||
| axis.emplace_back(i); | |||||
| } | |||||
| attr->axis = axis; | |||||
| attr->epsilon = 0.0f; | |||||
| op->primitive->value.type = schema::PrimitiveType_L2Norm; | |||||
| op->primitive->value.value = attr.release(); | |||||
| // set input | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteL2NormParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H | |||||
| @@ -96,6 +96,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| for (const auto &tflite_op : tflite_subgraph->operators) { | for (const auto &tflite_op : tflite_subgraph->operators) { | ||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | ||||
| auto op_type = GetMSOpType(tflite_op_type); | auto op_type = GetMSOpType(tflite_op_type); | ||||
| if (op_type == "CUSTOM") { | |||||
| auto custom_type = (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code; | |||||
| MS_LOG(ERROR) << "CUSTOM op is not supported, the type is " << custom_type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::unique_ptr<schema::CNodeT> op(new schema::CNodeT); | std::unique_ptr<schema::CNodeT> op(new schema::CNodeT); | ||||
| op->name = op_type + "-" + std::to_string(idx++); | op->name = op_type + "-" + std::to_string(idx++); | ||||
| @@ -216,7 +221,7 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) { | |||||
| STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) { | |||||
| for (auto &op : sub_graph->nodes) { | for (auto &op : sub_graph->nodes) { | ||||
| if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| auto attr = op->primitive->value.AsDepthwiseConv2D(); | auto attr = op->primitive->value.AsDepthwiseConv2D(); | ||||
| @@ -248,7 +253,6 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) { | |||||
| auto weight_id = op->inputIndex[1]; | auto weight_id = op->inputIndex[1]; | ||||
| auto &weight_tensor = sub_graph->allTensors.at(weight_id); | auto &weight_tensor = sub_graph->allTensors.at(weight_id); | ||||
| if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { | if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { | ||||
| // convert weight format KHWC -> CHWK | |||||
| auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK); | auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | ||||
| @@ -256,13 +260,13 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) { | |||||
| } | } | ||||
| } | } | ||||
| if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { | if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { | ||||
| // convert weight format KHWC -> CHWK | |||||
| auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK); | auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | |||||
| MS_LOG(ERROR) << "Trans filter format failed."; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| weight_tensor->format = schema::Format_CHWK; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -303,8 +307,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, const std::s | |||||
| } | } | ||||
| // update for depthwiseConv | // update for depthwiseConv | ||||
| if (UpdateOp(sub_graph.get()) != RET_OK) { | |||||
| MS_LOG(ERROR) << "update depthwise conv failed"; | |||||
| if (ConvertGroupDepthwiseOp(sub_graph.get()) != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert group depthwise conv failed"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -67,7 +67,7 @@ class TfliteModelParser : public ModelParser { | |||||
| STATUS GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | STATUS GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | ||||
| schema::MetaGraphT* sub_graph); | schema::MetaGraphT* sub_graph); | ||||
| STATUS UpdateOp(schema::MetaGraphT* sub_graph); | |||||
| STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph); | |||||
| private: | private: | ||||
| std::vector<int32_t> tensorsId; | std::vector<int32_t> tensorsId; | ||||