Merge pull request !3688 from ghzl/add-deconv-parser
@@ -0,0 +1,68 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_deconv_parser.h"

namespace mindspore {
namespace lite {
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
                                 const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
                                 const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
                                 const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set,
                                 schema::CNodeT *op,
                                 TensorCache *tensor_cache, bool quantized_model) {
  MS_LOG(DEBUG) << "parse tflite TransposeConv op";
  std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();

  const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
  if (tflite_attr == nullptr) {
    MS_LOG(ERROR) << "get op: " << op->name << " attr failed";
    return RET_NULL_PTR;
  }
  attr->group = 1;
  attr->strideW = tflite_attr->stride_w;
  attr->strideH = tflite_attr->stride_h;
  attr->dilateH = 1;
  attr->dilateW = 1;
  attr->padMode = GetPadMode(tflite_attr->padding);
  attr->format = schema::Format_NHWC;

  // get the conv op weight tensor; TFLite TransposeConv inputs are [output_shape, weights, input]
  auto weight_index = tflite_op->inputs[1];
  const auto &weight_tensor = tflite_tensors[weight_index];
  std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
  if (RET_OK != ParseWeight(weight_tensors, tflite_model_buffer, tensor_cache, schema::Format_KHWC)) {
    return RET_ERROR;
  }

  // the weight shape is in KHWC order
  auto weight_shape = weight_tensor->shape;
  attr->channelIn = weight_shape[KHWC_C];
  attr->channelOut = weight_shape[KHWC_K];
  attr->kernelW = weight_shape[KHWC_W];
  attr->kernelH = weight_shape[KHWC_H];

  if (op != nullptr) {
    op->primitive = std::make_unique<schema::PrimitiveT>();
    op->primitive->value.type = schema::PrimitiveType_DeConv2D;
    op->primitive->value.value = attr.release();
  }
  return RET_OK;
}

TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser());
}  // namespace lite
}  // namespace mindspore
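For readers unfamiliar with the KHWC_* indices used above: they index a weight shape stored in K (output channel), H, W, C (input channel) order. Below is a minimal standalone sketch of that mapping; the constant values and the example shape are illustrative assumptions, not taken from the patch.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical index constants for a KHWC-ordered weight shape.
constexpr int KHWC_K = 0;  // output channels
constexpr int KHWC_H = 1;  // kernel height
constexpr int KHWC_W = 2;  // kernel width
constexpr int KHWC_C = 3;  // input channels

int main() {
  // Example TransposeConv weight shape: 32 output channels, 3x3 kernel, 16 input channels.
  std::vector<int32_t> weight_shape{32, 3, 3, 16};
  std::cout << "channelOut = " << weight_shape[KHWC_K] << std::endl;  // 32
  std::cout << "kernelH    = " << weight_shape[KHWC_H] << std::endl;  // 3
  std::cout << "kernelW    = " << weight_shape[KHWC_W] << std::endl;  // 3
  std::cout << "channelIn  = " << weight_shape[KHWC_C] << std::endl;  // 16
  return 0;
}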
@@ -0,0 +1,41 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef PREDICT_TFLITE_DECONV_PARSER_H
#define PREDICT_TFLITE_DECONV_PARSER_H

#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

namespace mindspore {
namespace lite {
class TfliteDeConvParser : public TfliteNodeParser {
 public:
  TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {}

  STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
               const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
               const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
               const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, schema::CNodeT *op,
               TensorCache *tensor_cache,
               bool quantizedModel) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // PREDICT_TFLITE_DECONV_PARSER_H
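The g_tfliteDeConv2DParser global at the end of the .cc file relies on static registration: the parser is made visible to the converter before main() runs. The following is a simplified, hypothetical sketch of that pattern (all Demo* names are invented; this is not the project's actual TfliteNodeRegister implementation).

#include <iostream>
#include <map>
#include <string>

// Stand-in for the parser base class.
class DemoNodeParser {
 public:
  virtual ~DemoNodeParser() = default;
};

// Function-local static registry keyed by op type name (Meyers singleton).
std::map<std::string, DemoNodeParser *> &DemoRegistry() {
  static std::map<std::string, DemoNodeParser *> registry;
  return registry;
}

// Constructing a global of this type registers the parser as a side effect.
struct DemoNodeRegister {
  DemoNodeRegister(const std::string &name, DemoNodeParser *parser) { DemoRegistry()[name] = parser; }
};

class DemoDeConvParser : public DemoNodeParser {};

// Mirrors the shape of: TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser());
// (the object is deliberately leaked, like the original global).
DemoNodeRegister g_demoDeConvRegister("DeConv2D", new DemoDeConvParser());

int main() {
  std::cout << "DeConv2D registered: " << (DemoRegistry().count("DeConv2D") != 0) << std::endl;  // 1
  return 0;
}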
@@ -112,10 +112,17 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
   return RET_OK;
 }

-STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
+STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
+                                        const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
                                         const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache) {
-  for (const auto &tfliteIndex : tflite_op->inputs) {
-    const auto &tflite_tensor = tflite_subgraph->tensors[tfliteIndex];
+  auto op_type = GetTfliteNodeType(tflite_op, tflite_model);
+  std::vector<int32_t> op_inputs(tflite_op->inputs);
+  if (op_type == "DeConv2D") {
+    reverse(op_inputs.begin(), op_inputs.end());
+  }
+
+  for (const auto &tflite_index : op_inputs) {
+    const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index];
     auto tensor_name = tflite_tensor->name;
     auto op = tfliteOpMap[tflite_op.get()];
     unsigned int index = tensorCache->FindTensor(tensor_name);
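The reversal above compensates for TFLite's operand order: TransposeConv carries its inputs as [output_shape, weights, input], so reversing them puts the data tensor first, which is the order the rest of the converter walks. A tiny standalone sketch of the effect; the tensor index values are illustrative only.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Illustrative tensor indices in TFLite TransposeConv order: output_shape, weights, input.
  std::vector<int32_t> op_inputs{7, 8, 9};
  std::reverse(op_inputs.begin(), op_inputs.end());
  // After reversal the input tensor (9) comes first, then weights (8), then output_shape (7).
  for (auto idx : op_inputs) {
    std::cout << idx << " ";
  }
  std::cout << std::endl;  // prints: 9 8 7
  return 0;
}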
@@ -228,8 +235,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
   }

   for (const auto &tflite_op : tflite_subgraph->operators) {
-    auto statusTmp = SetOpInputIdx(tflite_subgraph, tflite_op, &tensorCache);
-    if (statusTmp != RET_OK) {
+    auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache);
+    if (status_tmp != RET_OK) {
       // MS_LOGE("Set Op %s Input Index Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str());
     }
   }
@@ -73,7 +73,8 @@ class TfliteModelParser : public ModelParser {
                      schema::CNodeT *op,
                      TensorCache *tensorCache);

-  STATUS SetOpInputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
+  STATUS SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
+                       const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
                        const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache);

   std::map<std::string, schema::CNodeT *> opMap;
@@ -55,6 +55,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
   {tflite::BuiltinOperator_ARG_MAX, "Argmax"},
   {tflite::BuiltinOperator_SQUARED_DIFFERENCE, "SquaredDifference"},
   {tflite::BuiltinOperator_FAKE_QUANT, "FakeQuant"},
+  {tflite::BuiltinOperator_TRANSPOSE_CONV, "DeConv2D"},
 };

 std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) {
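For context on how the new map entry is consumed: GetMSOpType presumably looks the TFLite builtin code up in tfMsOpTypeMap and returns the mapped MindSpore op name, so TRANSPOSE_CONV now resolves to "DeConv2D" and dispatches to the registered parser. A simplified, hypothetical illustration of such a lookup with a fallback for unmapped operators (all Demo* names are invented; this is not the project's actual implementation):

#include <iostream>
#include <map>
#include <string>

// Stand-in for the TFLite builtin operator enum.
enum class DemoBuiltinOperator { ARG_MAX, FAKE_QUANT, TRANSPOSE_CONV, CUSTOM };

const std::map<DemoBuiltinOperator, std::string> kDemoOpTypeMap{
    {DemoBuiltinOperator::ARG_MAX, "Argmax"},
    {DemoBuiltinOperator::FAKE_QUANT, "FakeQuant"},
    {DemoBuiltinOperator::TRANSPOSE_CONV, "DeConv2D"},
};

// Return the mapped op name, or a fallback string when the operator is unknown.
std::string DemoGetMSOpType(DemoBuiltinOperator op) {
  auto it = kDemoOpTypeMap.find(op);
  return it == kDemoOpTypeMap.end() ? "Unsupported" : it->second;
}

int main() {
  std::cout << DemoGetMSOpType(DemoBuiltinOperator::TRANSPOSE_CONV) << std::endl;  // DeConv2D
  std::cout << DemoGetMSOpType(DemoBuiltinOperator::CUSTOM) << std::endl;          // Unsupported
  return 0;
}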