Browse Source

!8426 rewrite tflite parsers

From: @cjh9368
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
47f86afa11
8 changed files with 350 additions and 27 deletions
  1. +1
    -1
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc
  2. +3
    -4
      mindspore/lite/tools/converter/model_parser.h
  3. +273
    -0
      mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.cc
  4. +44
    -0
      mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.h
  5. +6
    -8
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
  6. +7
    -4
      mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h
  7. +11
    -5
      mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc
  8. +5
    -5
      mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h

+ 1
- 1
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc View File

@@ -23,7 +23,7 @@ namespace mindspore {

schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const string &model_path, const string &weight_path) {
lite::TfliteModelParser parser;
meta_graph = parser.ParseToFb(model_path, weight_path);
meta_graph = parser.ParseToFb(model_path, weight_path, schema::QuantType_QUANT_NONE);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr;


+ 3
- 4
mindspore/lite/tools/converter/model_parser.h View File

@@ -28,12 +28,11 @@ namespace mindspore::lite {
using namespace schema;
class ModelParser {
public:
ModelParser() {}
ModelParser() = default;

virtual ~ModelParser() {}
virtual ~ModelParser() = default;

FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) {
virtual FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) {
auto *meta_graph = ParseToFb(modelFile, weightFile, quantType);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "parse model to fb failed";


+ 273
- 0
mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.cc View File

@@ -0,0 +1,273 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/model_parser_for_tflite.h"
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include "src/param_value_lite.h"

namespace mindspore::lite {

FuncGraphPtr ModelParserForTflite::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
// load graph
tfliteModel = ReadTfliteModel(modelFile.c_str());
if (tfliteModel == nullptr) {
MS_LOG(ERROR) << "read tflite model failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}

if (tfliteModel->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
funcGraphPtr = std::make_shared<FuncGraph>();

auto status = ConvertGraphInputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}

status = ConvertOps();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}

status = ConvertGraphOutputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
return funcGraphPtr;
}

STATUS ModelParserForTflite::ConvertOps() {
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
const auto &tfliteModelBuffers = tfliteModel->buffers;
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
STATUS status = RET_OK;
int opIdx = 0;
for (auto &op : tfliteSubgraph->operators) {
auto tfliteOpType = (tfliteModel->operator_codes[op->opcode_index])->builtin_code;
auto opType = GetMSOpType(tfliteOpType);

// parse primitive
auto nodeParser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType);
if (nodeParser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(opType);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
}
PrimitiveC *primitiveC = nullptr;
if (status == RET_OK) {
status = nodeParser->Parse(op, tfliteModel, primitiveC);
if (status != RET_OK) {
if (status == RET_NOT_FIND_OP) {
opType = (opType != "Custom" ? opType : (tfliteModel->operator_codes[op->opcode_index])->custom_code);
NoSupportOp::GetInstance()->InsertOp(opType);
} else {
MS_LOG(ERROR) << "node " << opType.c_str() << " parser failed";
}
continue;
}

std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))};
// parse inputs
for (auto inputIdx : op->inputs) {
const auto &inputTensor = tfliteSubgraph->tensors[inputIdx];
if (nodes.find(inputIdx) != nodes.end()) {
opInputs.emplace_back(nodes.at(inputIdx));
continue;
}
// const tensor
if (tfliteModelBuffers.at(inputTensor->buffer)->data.empty()) {
ParameterPtr parameter;
ConvertConstTensor(inputTensor.get(), parameter);
opInputs.emplace_back(parameter);
nodes.insert(std::pair(inputIdx, parameter));
continue;
}
MS_LOG(ERROR) << "tensor" << inputIdx << " is neither a node output nor a weight tensor.";
return RET_ERROR;
}
auto newCNode = funcGraphPtr->NewCNode(opInputs);
newCNode->set_fullname_with_scope(opType + "-" + std::to_string(opIdx++));

// parse outputs
status = ConvertOutputTensor(op.get(), newCNode);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << newCNode->fullname_with_scope() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
}
}
}
return status;
}

STATUS ModelParserForTflite::ConvertGraphInputs() {
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
for (auto tfliteGraphInput : tfliteSubgraph->inputs) {
if (tfliteGraphInput < 0) {
tfliteGraphInput = tfliteGraphInput + tfliteSubgraph->tensors.size();
}
auto parameter = funcGraphPtr->add_parameter();
const auto &tensor = tfliteSubgraph->tensors.at(tfliteGraphInput);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor);
parameter->set_name("graph_input_" + std::to_string(tfliteGraphInput) + "_parameter");
nodes.insert(std::pair(tfliteGraphInput, parameter));
}
return RET_OK;
}
STATUS ModelParserForTflite::ConvertGraphOutputs() {
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
if (tfliteSubgraph->outputs.size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs;
auto make_tuple_prim_ptr = GetMakeTuplePrim();
if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
return RET_NULL_PTR;
}
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
make_tuple_inputs.emplace_back(make_tuple_prim);
for (auto outputNode : tfliteSubgraph->outputs) {
auto cnode = nodes.at(outputNode);
if (nullptr == cnode) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
}
make_tuple_inputs.emplace_back(cnode);
}
auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs);
make_tuple_cnode->set_fullname_with_scope("return tuple");

std::vector<AnfNodePtr> op_inputs;
auto return_prim_ptr = GetReturnPrim();
if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR;
}
auto value_node = NewValueNode(return_prim_ptr);
op_inputs.emplace_back(value_node);
op_inputs.emplace_back(make_tuple_cnode);
auto cnode = funcGraphPtr->NewCNode(op_inputs);
cnode->set_fullname_with_scope("return");
funcGraphPtr->set_return(cnode);
} else {
auto returnPrim = GetReturnPrim();
if (returnPrim == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR;
}
auto valueNode = NewValueNode(returnPrim);
std::vector<AnfNodePtr> opInputs{valueNode};
auto cnode = nodes.at(tfliteSubgraph->outputs.front());
if (nullptr == cnode) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
}
opInputs.emplace_back(cnode);
auto returnCnode = funcGraphPtr->NewCNode(opInputs);
returnCnode->set_fullname_with_scope("return");
funcGraphPtr->set_return(returnCnode);
}
return RET_OK;
}

STATUS ModelParserForTflite::ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter) {
parameter = funcGraphPtr->add_parameter();
const auto &tfliteModelBuffers = tfliteModel->buffers;
auto type_id = static_cast<TypeId>(tensor->type);
auto type_ptr = TypeIdToType(type_id);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor);
parameter->set_name("const_" + std::to_string(nodes.size()) + "_parameter");

ParamValueLitePtr paramValue = std::make_shared<ParamValueLite>();
MS_ASSERT(paramValue != nullptr);
paramValue->set_tensor_shape(tensor->shape);
paramValue->set_tensor_type(GetTfliteDataType(tensor->type));
paramValue->set_format(schema::Format::Format_NHWC);
const auto &data = tfliteModelBuffers.at(tensor->buffer)->data;
if (!data.empty()) {
auto size = data.size();
char *tensor_data = new (std::nothrow) char[size];
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new char[] failed";
return RET_MEMORY_FAILED;
}
std::memcpy(tensor_data, data.data(), size);
paramValue->set_tensor_addr(tensor_data);
paramValue->set_tensor_size(size);
parameter->set_default_param(paramValue);
}
return RET_OK;
}

STATUS ModelParserForTflite::ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode) {
MS_ASSERT(op != nullptr);
MS_ASSERT(dstCNode != nullptr);
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
if (op->outputs.size() == 1) {
const auto &tensor = tfliteSubgraph->tensors.at(op->outputs.front());
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type));
dstCNode->set_abstract(std::make_shared<abstract::AbstractTensor>(typePtr, shape_vector));
nodes.insert(std::pair(op->outputs.front(), dstCNode));
} else {
AbstractBasePtrList abstractList;
for (auto outputIdx : op->outputs) {
const auto &tensor = tfliteSubgraph->tensors.at(outputIdx);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type));
abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(typePtr, shape_vector));
auto tupleGetItemPrimPtr = GetTupleGetItemPrim();
if (tupleGetItemPrimPtr == nullptr) {
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
return RET_NULL_PTR;
}
auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr);
auto getItemValue = NewValueNode(MakeValue<int>(outputIdx));
std::vector<AnfNodePtr> inputs{tupleGetItemPrim, dstCNode, getItemValue};
CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs);
getItemCNode->set_fullname_with_scope(dstCNode->fullname_with_scope() + "_getitem_" + std::to_string(outputIdx));
nodes.insert(std::pair(outputIdx, getItemCNode));
}
dstCNode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList));
}
return RET_OK;
}
} // namespace mindspore::lite

+ 44
- 0
mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.h View File

@@ -0,0 +1,44 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MODEL_PARSER_FOR_TFLITE_H
#define LITE_MODEL_PARSER_FOR_TFLITE_H

#include <string>
#include <unordered_map>
#include <memory>
#include "tools/converter/parser/tflite/tflite_model_parser.h"

namespace mindspore::lite {
class ModelParserForTflite : public TfliteModelParser {
public:
ModelParserForTflite() = default;

~ModelParserForTflite() override = default;

FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) override;

private:
std::unordered_map<int, AnfNodePtr> nodes;
std::unique_ptr<tflite::ModelT> tfliteModel;
FuncGraphPtr funcGraphPtr;
STATUS ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter);
STATUS ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode);
STATUS ConvertOps();
STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs();
};
} // namespace mindspore::lite
#endif // LITE_MODEL_PARSER_FOR_TFLITE_H

+ 6
- 8
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h View File

@@ -33,8 +33,7 @@
#include "tools/common/tensor_util.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteModelParser : public ModelParser {
public:
TfliteModelParser();
@@ -42,9 +41,9 @@ class TfliteModelParser : public ModelParser {
~TfliteModelParser() override;

schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quantType = QuantType_QUANT_NONE) override;
const QuantType &quantTyp) override;

private:
protected:
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);

STATUS CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
@@ -64,6 +63,8 @@ class TfliteModelParser : public ModelParser {

STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph);

QuantType quantType = QuantType_QUANT_NONE;
char *tfliteModelBuf = nullptr;
std::unique_ptr<schema::MetaGraphT> ConstructMainGraph(const std::unique_ptr<tflite::ModelT> &tflite_model,
const QuantType &quant_type);

@@ -73,9 +74,6 @@ class TfliteModelParser : public ModelParser {

std::map<std::string, schema::CNodeT *> opMap;
std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap;
QuantType quantType = QuantType_QUANT_NONE;
char *tfliteModelBuf = nullptr;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H

+ 7
- 4
mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h View File

@@ -22,6 +22,7 @@
#include <map>
#include <memory>
#include <utility>
#include "src/ops/primitive_c.h"
#include "src/common/log_adapter.h"
#include "schema/inner/model_generated.h"
#include "schema/schema_generated.h"
@@ -30,8 +31,7 @@
#include "include/errorcode.h"
#include "tools/converter/parser/tflite/tflite_util.h"

namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteNodeParser {
public:
explicit TfliteNodeParser(const std::string &node_name) : name(node_name) {}
@@ -41,6 +41,10 @@ class TfliteNodeParser {
virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) = 0;
virtual STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, PrimitiveC *primitiveC) {
return RET_OK;
}

void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, schema::Format format) {
int new_idx = tensors_info->tensorsId.size();
@@ -158,7 +162,6 @@ class TfliteNodeParser {
{tflite::TensorType_BOOL, TypeId::kNumberTypeBool},
};
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_H

+ 11
- 5
mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc View File

@@ -17,10 +17,9 @@
#include "tools/converter/parser/tflite/tflite_softmax_parser.h"
#include <vector>
#include <memory>
#include <map>
#include "src/ops/softmax.h"

namespace mindspore {
namespace lite {
namespace mindspore::lite {
STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
@@ -51,6 +50,13 @@ STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
return RET_OK;
}

STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, PrimitiveC *primitiveC) {
auto softmaxPrimitive = new SoftMax();
softmaxPrimitive->SetAxis(-1);
primitiveC = softmaxPrimitive;
return RET_OK;
}

TfliteNodeRegister g_tfliteSoftmaxParser("Softmax", new TfliteSoftmaxParser());
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

+ 5
- 5
mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h View File

@@ -19,12 +19,11 @@

#include <memory>
#include <vector>
#include <map>
#include <unordered_map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"

namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteSoftmaxParser : public TfliteNodeParser {
public:
TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {}
@@ -32,8 +31,9 @@ class TfliteSoftmaxParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, const std::unique_ptr<tflite::ModelT> &tflite_model,
PrimitiveC *primitiveC) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H

Loading…
Cancel
Save