|
|
|
@@ -81,7 +81,6 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std:: |
|
|
|
|
|
|
|
STATUS TfliteModelParser::ConvertOps() { |
|
|
|
const auto &tflite_subgraph = tflite_model_->subgraphs.front(); |
|
|
|
const auto &tflite_model_buffers = tflite_model_->buffers; |
|
|
|
NoSupportOp::GetInstance()->SetFmkType("TFLITE"); |
|
|
|
STATUS status = RET_OK; |
|
|
|
int op_idx = 0; |
|
|
|
@@ -117,6 +116,9 @@ STATUS TfliteModelParser::ConvertOps() { |
|
|
|
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))}; |
|
|
|
// parse inputs |
|
|
|
for (auto input_idx : op->inputs) { |
|
|
|
if (tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED && input_idx == -1) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (input_idx < 0) { |
|
|
|
input_idx += tflite_subgraph->tensors.size(); |
|
|
|
} |
|
|
|
@@ -126,18 +128,14 @@ STATUS TfliteModelParser::ConvertOps() { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// const tensor |
|
|
|
if (!tflite_model_buffers.at(input_tensor->buffer)->data.empty()) { |
|
|
|
auto parameter = func_graph_->add_parameter(); |
|
|
|
status = ConvertConstTensor(input_tensor.get(), parameter.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; |
|
|
|
return status; |
|
|
|
} |
|
|
|
op_inputs.emplace_back(parameter); |
|
|
|
nodes_.insert(std::pair(input_idx, parameter)); |
|
|
|
continue; |
|
|
|
auto parameter = func_graph_->add_parameter(); |
|
|
|
status = ConvertConstTensor(input_tensor.get(), parameter.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; |
|
|
|
return status; |
|
|
|
} |
|
|
|
MS_LOG(WARNING) << "tensor " << input_idx << " is neither a node output nor a weight tensor."; |
|
|
|
op_inputs.emplace_back(parameter); |
|
|
|
nodes_.insert(std::pair(input_idx, parameter)); |
|
|
|
} |
|
|
|
auto new_cnode = func_graph_->NewCNode(op_inputs); |
|
|
|
new_cnode->set_fullname_with_scope(op_name); |
|
|
|
@@ -268,6 +266,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { |
|
|
|
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); |
|
|
|
make_tuple_inputs.emplace_back(make_tuple_prim); |
|
|
|
for (auto outputNode : tflite_subgraph->outputs) { |
|
|
|
outputNode = outputNode < 0 ? outputNode + tflite_subgraph->tensors.size() : outputNode; |
|
|
|
auto cnode = nodes_.at(outputNode); |
|
|
|
if (nullptr == cnode) { |
|
|
|
MS_LOG(ERROR) << "Can't find input node."; |
|
|
|
@@ -296,9 +295,12 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { |
|
|
|
MS_LOG(ERROR) << "GetReturnPrim return nullptr"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
int outputNode = tflite_subgraph->outputs.front() < 0 |
|
|
|
? static_cast<int>(tflite_subgraph->outputs.front() + tflite_subgraph->tensors.size()) |
|
|
|
: static_cast<int>(tflite_subgraph->outputs.front()); |
|
|
|
auto valueNode = NewValueNode(returnPrim); |
|
|
|
std::vector<AnfNodePtr> op_inputs{valueNode}; |
|
|
|
auto cnode = nodes_.at(tflite_subgraph->outputs.front()); |
|
|
|
auto cnode = nodes_.at(outputNode); |
|
|
|
if (nullptr == cnode) { |
|
|
|
MS_LOG(ERROR) << "Can't find input node."; |
|
|
|
return RET_NOT_FIND_OP; |
|
|
|
@@ -345,8 +347,8 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para |
|
|
|
} |
|
|
|
std::memcpy(tensor_data, data.data(), size); |
|
|
|
param_value->SetTensorData(tensor_data, size); |
|
|
|
parameter->set_default_param(param_value); |
|
|
|
} |
|
|
|
parameter->set_default_param(param_value); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
|