Merge pull request !3688 from ghzl/add-deconv-parsertags/v0.7.0-beta
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_deconv_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) { | |||||
| MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | |||||
| std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT()); | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->group = 1; | |||||
| attr->strideW = tflite_attr->stride_w; | |||||
| attr->strideH = tflite_attr->stride_h; | |||||
| attr->dilateH = 1; | |||||
| attr->dilateW = 1; | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||||
| attr->format = schema::Format_NHWC; | |||||
| // get the conv op weight tensor | |||||
| auto weight_index = tflite_op->inputs[1]; | |||||
| const auto &weight_tensor = tflite_tensors[weight_index]; | |||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | |||||
| if (RET_OK != ParseWeight(weight_tensors, tflite_model_buffer, tensor_cache, schema::Format_KHWC)) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto weight_shape = weight_tensor->shape; | |||||
| attr->channelIn = weight_shape[KHWC_C]; | |||||
| attr->channelOut = weight_shape[KHWC_K]; | |||||
| attr->kernelW = weight_shape[KHWC_W]; | |||||
| attr->kernelH = weight_shape[KHWC_H]; | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_DECONV_PARSER_H | |||||
| #define PREDICT_TFLITE_DECONV_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteDeConvParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_DECONV_PARSER_H | |||||
| @@ -112,10 +112,17 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache) { | const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache) { | ||||
| for (const auto &tfliteIndex : tflite_op->inputs) { | |||||
| const auto &tflite_tensor = tflite_subgraph->tensors[tfliteIndex]; | |||||
| auto op_type = GetTfliteNodeType(tflite_op, tflite_model); | |||||
| std::vector<int32_t> op_inputs(tflite_op->inputs); | |||||
| if (op_type == "DeConv2D") { | |||||
| reverse(op_inputs.begin(), op_inputs.end()); | |||||
| } | |||||
| for (const auto &tflite_index : op_inputs) { | |||||
| const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index]; | |||||
| auto tensor_name = tflite_tensor->name; | auto tensor_name = tflite_tensor->name; | ||||
| auto op = tfliteOpMap[tflite_op.get()]; | auto op = tfliteOpMap[tflite_op.get()]; | ||||
| unsigned int index = tensorCache->FindTensor(tensor_name); | unsigned int index = tensorCache->FindTensor(tensor_name); | ||||
| @@ -228,8 +235,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st | |||||
| } | } | ||||
| for (const auto &tflite_op : tflite_subgraph->operators) { | for (const auto &tflite_op : tflite_subgraph->operators) { | ||||
| auto statusTmp = SetOpInputIdx(tflite_subgraph, tflite_op, &tensorCache); | |||||
| if (statusTmp != RET_OK) { | |||||
| auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache); | |||||
| if (status_tmp != RET_OK) { | |||||
| // MS_LOGE("Set Op %s Input Index Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str()); | // MS_LOGE("Set Op %s Input Index Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -73,7 +73,8 @@ class TfliteModelParser : public ModelParser { | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensorCache); | TensorCache *tensorCache); | ||||
| STATUS SetOpInputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| STATUS SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache); | const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache); | ||||
| std::map<std::string, schema::CNodeT *> opMap; | std::map<std::string, schema::CNodeT *> opMap; | ||||
| @@ -55,6 +55,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{ | |||||
| {tflite::BuiltinOperator_ARG_MAX, "Argmax"}, | {tflite::BuiltinOperator_ARG_MAX, "Argmax"}, | ||||
| {tflite::BuiltinOperator_SQUARED_DIFFERENCE, "SquaredDifference"}, | {tflite::BuiltinOperator_SQUARED_DIFFERENCE, "SquaredDifference"}, | ||||
| {tflite::BuiltinOperator_FAKE_QUANT, "FakeQuant"}, | {tflite::BuiltinOperator_FAKE_QUANT, "FakeQuant"}, | ||||
| {tflite::BuiltinOperator_TRANSPOSE_CONV, "DeConv2D"}, | |||||
| }; | }; | ||||
| std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) { | std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) { | ||||