| @@ -226,6 +226,7 @@ union PrimitiveType { | |||||
| InstanceNorm, | InstanceNorm, | ||||
| Identity, | Identity, | ||||
| LayerNorm, | LayerNorm, | ||||
| While, | |||||
| } | } | ||||
| enum QuantType: int { | enum QuantType: int { | ||||
| @@ -1103,3 +1103,8 @@ table LayerNorm { | |||||
| elementwiseAffine : bool; | elementwiseAffine : bool; | ||||
| } | } | ||||
| table While { | |||||
| condSubgraphIndex : int; | |||||
| bodySubgraphIndex : int; | |||||
| } | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/while.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "src/ops/populate/populate_register.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| typedef struct WhileParemeter { | |||||
| OpParameter op_parameter_; | |||||
| int body_subgraph_index; | |||||
| int cond_subgraph_index; | |||||
| } WhileParemeter; | |||||
| OpParameter *PopulateWhileParemeter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| WhileParemeter *while_paremeter = reinterpret_cast<WhileParemeter *>(malloc(sizeof(WhileParemeter))); | |||||
| if (while_paremeter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc WhileParemeter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(while_paremeter, 0, sizeof(WhileParemeter)); | |||||
| auto param = reinterpret_cast<mindspore::lite::While *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| while_paremeter->op_parameter_.type_ = primitive->Type(); | |||||
| while_paremeter->body_subgraph_index = param->GetBodySubgraphIndex(); | |||||
| while_paremeter->cond_subgraph_index = param->GetCondSubgraphIndex(); | |||||
| return reinterpret_cast<OpParameter *>(while_paremeter); | |||||
| } | |||||
| Registry WhileParemeterRegistry(schema::PrimitiveType_While, PopulateWhileParemeter); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -144,6 +144,7 @@ | |||||
| #include "src/ops/mfcc.h" | #include "src/ops/mfcc.h" | ||||
| #include "src/ops/identity.h" | #include "src/ops/identity.h" | ||||
| #include "src/ops/instance_norm.h" | #include "src/ops/instance_norm.h" | ||||
| #include "src/ops/while.h" | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| #include "src/ops/neg_grad.h" | #include "src/ops/neg_grad.h" | ||||
| @@ -499,6 +500,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<Maximum>(prim, inputs, quantType); | return NewPrimitiveC<Maximum>(prim, inputs, quantType); | ||||
| } else if (op_type == "Split") { | } else if (op_type == "Split") { | ||||
| return NewPrimitiveC<Split>(prim, inputs, quantType); | return NewPrimitiveC<Split>(prim, inputs, quantType); | ||||
| } else if (op_type == "While") { | |||||
| return NewPrimitiveC<While>(prim, inputs, quantType); | |||||
| } else if (op_type == "OneHot") { | } else if (op_type == "OneHot") { | ||||
| return NewPrimitiveC<OneHot>(prim, inputs, quantType); | return NewPrimitiveC<OneHot>(prim, inputs, quantType); | ||||
| @@ -793,6 +796,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new Mfcc(primitive); | return new Mfcc(primitive); | ||||
| case schema::PrimitiveType_InstanceNorm: | case schema::PrimitiveType_InstanceNorm: | ||||
| return new InstanceNorm(primitive); | return new InstanceNorm(primitive); | ||||
| case schema::PrimitiveType_While: | |||||
| return new While(primitive); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| case schema::PrimitiveType_ActivationGrad: | case schema::PrimitiveType_ActivationGrad: | ||||
| @@ -0,0 +1,107 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/while.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| void While::SetCondSubgraphIndex(const int cond_subgraph_index) { | |||||
| this->primitive_->value.AsWhile()->condSubgraphIndex = cond_subgraph_index; | |||||
| } | |||||
| void While::SetBodySubgraphIndex(const int body_subgraph_index) { | |||||
| this->primitive_->value.AsWhile()->bodySubgraphIndex = body_subgraph_index; | |||||
| } | |||||
| int While::GetCondSubgraphIndex() const { return this->primitive_->value.AsWhile()->condSubgraphIndex; } | |||||
| int While::GetBodySubgraphIndex() const { return this->primitive_->value.AsWhile()->bodySubgraphIndex; } | |||||
| int While::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_While; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_While) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::WhileT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->bodySubgraphIndex = GetValue<bool>(prim.GetAttr("body_subgraph_index")); | |||||
| attr->condSubgraphIndex = GetValue<bool>(prim.GetAttr("cond_subgraph_index")); | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int While::GetCondSubgraphIndex() const { return this->primitive_->value_as_While()->condSubgraphIndex(); } | |||||
| int While::GetBodySubgraphIndex() const { return this->primitive_->value_as_While()->bodySubgraphIndex(); } | |||||
| int While::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_While(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_While return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto cond_subgraph_index = attr->condSubgraphIndex(); | |||||
| auto body_subgraph_index = attr->bodySubgraphIndex(); | |||||
| auto val_offset = schema::CreateWhile(*fbb, body_subgraph_index, cond_subgraph_index); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_While, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| PrimitiveC *WhileCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<While>(primitive); } | |||||
| Registry WhileRegistry(schema::PrimitiveType_While, WhileCreator); | |||||
| #endif | |||||
| int While::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| if (inputs_.size() != outputs_.size()) { | |||||
| MS_LOG(ERROR) << "The number of inputs and outputs varies"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||||
| outputs_[i]->set_data_type(inputs_[i]->data_type()); | |||||
| outputs_[i]->SetFormat(inputs_[i]->GetFormat()); | |||||
| outputs_[i]->set_shape(inputs_[i]->shape()); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_WHILE_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_WHILE_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include <memory> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class While : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(While, PrimitiveC); | |||||
| While() = default; | |||||
| explicit While(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| void SetCondSubgraphIndex(const int cond_subgraph_index); | |||||
| void SetBodySubgraphIndex(const int body_subgraph_index); | |||||
| #else | |||||
| While() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| int GetCondSubgraphIndex() const; | |||||
| int GetBodySubgraphIndex() const; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_WHERE_H_ | |||||
| @@ -24,7 +24,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -75,10 +76,8 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | op->primitive->value.type = schema::PrimitiveType_Activation; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteActivationParser : public TfliteNodeParser { | |||||
| TfliteActivationParser() : TfliteNodeParser("node_name") {} | TfliteActivationParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteReluParser : public TfliteActivationParser { | class TfliteReluParser : public TfliteActivationParser { | ||||
| @@ -23,7 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAddNParser"; | MS_LOG(DEBUG) << "parse TfliteAddNParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -41,16 +42,14 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->N = tflite_model->subgraphs[0]->tensors.size() - 1; | |||||
| attr->N = tflite_subgraph->tensors.size() - 1; | |||||
| op->primitive->value.type = schema::PrimitiveType_AddN; | op->primitive->value.type = schema::PrimitiveType_AddN; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteAddNParser : public TfliteNodeParser { | |||||
| TfliteAddNParser() : TfliteNodeParser("AddN") {} | TfliteAddNParser() : TfliteNodeParser("AddN") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -47,7 +48,7 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| // get axis attr | // get axis attr | ||||
| auto axis_idx = tflite_op->inputs[1]; | auto axis_idx = tflite_op->inputs[1]; | ||||
| auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer; | |||||
| auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; | |||||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | auto &buf_data = tflite_model->buffers[buffer_idx]; | ||||
| if (buf_data == nullptr) { | if (buf_data == nullptr) { | ||||
| MS_LOG(ERROR) << "the buf data is null"; | MS_LOG(ERROR) << "the buf data is null"; | ||||
| @@ -63,10 +64,8 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMax; | op->primitive->value.type = schema::PrimitiveType_ArgMax; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteArgmaxParser : public TfliteNodeParser { | |||||
| TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} | TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteArgminParser"; | MS_LOG(DEBUG) << "parse TfliteArgminParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -47,7 +48,7 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| // get axis attr | // get axis attr | ||||
| auto axis_idx = tflite_op->inputs[1]; | auto axis_idx = tflite_op->inputs[1]; | ||||
| auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer; | |||||
| auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; | |||||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | auto &buf_data = tflite_model->buffers[buffer_idx]; | ||||
| if (buf_data == nullptr) { | if (buf_data == nullptr) { | ||||
| MS_LOG(ERROR) << "the buf data is null"; | MS_LOG(ERROR) << "the buf data is null"; | ||||
| @@ -63,10 +64,8 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMin; | op->primitive->value.type = schema::PrimitiveType_ArgMin; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteArgminParser : public TfliteNodeParser { | |||||
| TfliteArgminParser() : TfliteNodeParser("Argmin") {} | TfliteArgminParser() : TfliteNodeParser("Argmin") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,7 +24,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -168,17 +169,16 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| // set input | // set input | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -305,16 +305,15 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } | } | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -385,11 +384,9 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { | |||||
| TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} | TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteAddParser : public TfliteDoubleInputOpParser { | class TfliteAddParser : public TfliteDoubleInputOpParser { | ||||
| @@ -93,7 +94,8 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { | |||||
| TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteAbsParser : public TfliteSingleInputOpParser { | class TfliteAbsParser : public TfliteSingleInputOpParser { | ||||
| @@ -161,7 +163,8 @@ class TfliteCompareOpParser : public TfliteNodeParser { | |||||
| TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteEqualParser : public TfliteCompareOpParser { | class TfliteEqualParser : public TfliteCompareOpParser { | ||||
| @@ -25,7 +25,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -51,12 +52,11 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, | |||||
| attr->blockShape)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) { | |||||
| MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; | MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->crops)) { | |||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) { | |||||
| MS_LOG(ERROR) << "get batchToSpace -> crops failed"; | MS_LOG(ERROR) << "get batchToSpace -> crops failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -64,10 +64,8 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { | |||||
| TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {} | TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | ||||
| @@ -24,7 +24,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; | MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -42,8 +43,7 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, | |||||
| attr->dst_shape)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) { | |||||
| MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; | MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -51,10 +51,8 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteBroadcastToParser : public TfliteNodeParser { | |||||
| TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} | TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCastParser"; | MS_LOG(DEBUG) << "parse TfliteCastParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -40,13 +41,13 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; | |||||
| const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | attr->srcT = GetTfliteDataType(in_tensor->type); | ||||
| const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; | |||||
| const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -56,10 +57,8 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| op->primitive->value.type = schema::PrimitiveType_Cast; | op->primitive->value.type = schema::PrimitiveType_Cast; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteCastParser : public TfliteNodeParser { | |||||
| TfliteCastParser() : TfliteNodeParser("Cast") {} | TfliteCastParser() : TfliteNodeParser("Cast") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteConcatParser"; | MS_LOG(DEBUG) << "parse TfliteConcatParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -52,11 +53,9 @@ STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteConcatParser : public TfliteNodeParser { | |||||
| TfliteConcatParser() : TfliteNodeParser("Concat") {} | TfliteConcatParser() : TfliteNodeParser("Concat") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteConvParser"; | MS_LOG(DEBUG) << "parse TfliteConvParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -57,7 +58,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| // get the conv op weight tensor | // get the conv op weight tensor | ||||
| auto weight_index = tflite_op->inputs[1]; | auto weight_index = tflite_op->inputs[1]; | ||||
| const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; | |||||
| const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | if (weight_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "the weight tensor is null"; | MS_LOG(ERROR) << "the weight tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -70,7 +71,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[0]; | auto data_index = tflite_op->inputs[0]; | ||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||||
| std::vector<int> params; | std::vector<int> params; | ||||
| int status = | int status = | ||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | ||||
| @@ -87,14 +88,10 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | op->primitive->value.type = schema::PrimitiveType_Conv2D; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteConvParser : public TfliteNodeParser { | |||||
| TfliteConvParser() : TfliteNodeParser("Conv2D") {} | TfliteConvParser() : TfliteNodeParser("Conv2D") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -139,14 +139,15 @@ STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_at | |||||
| STATUS TfliteCustomParser::Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | STATUS TfliteCustomParser::Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph) { | |||||
| std::unique_ptr<schema::RfftT> attr = std::make_unique<schema::RfftT>(); | std::unique_ptr<schema::RfftT> attr = std::make_unique<schema::RfftT>(); | ||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<int> fft_length; | std::vector<int> fft_length; | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, fft_length)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, fft_length)) { | |||||
| MS_LOG(ERROR) << "rfft -> fftLength get failed"; | MS_LOG(ERROR) << "rfft -> fftLength get failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -181,7 +182,8 @@ STATUS TfliteCustomParser::FftImag(const std::vector<uint8_t> &custom_attr, sche | |||||
| } | } | ||||
| STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCustomParser"; | MS_LOG(DEBUG) << "parse TfliteCustomParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -209,7 +211,7 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| } else if (custom_type == "Mfcc") { | } else if (custom_type == "Mfcc") { | ||||
| status = Mfcc(custom_attr, op, tflite_op); | status = Mfcc(custom_attr, op, tflite_op); | ||||
| } else if (custom_type == "FlexRFFT") { | } else if (custom_type == "FlexRFFT") { | ||||
| status = Rfft(custom_attr, op, tflite_op, tflite_model); | |||||
| status = Rfft(custom_attr, op, tflite_op, tflite_model, tflite_subgraph); | |||||
| } else if (custom_type == "FlexReal") { | } else if (custom_type == "FlexReal") { | ||||
| status = FftReal(custom_attr, op, tflite_op); | status = FftReal(custom_attr, op, tflite_op); | ||||
| } else if (custom_type == "FlexImag") { | } else if (custom_type == "FlexImag") { | ||||
| @@ -222,12 +224,10 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| return status; | return status; | ||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| return status; | return status; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteCustomParser : public TfliteNodeParser { | |||||
| TfliteCustomParser() : TfliteNodeParser("Custom") {} | TfliteCustomParser() : TfliteNodeParser("Custom") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | const std::unique_ptr<tflite::OperatorT> &tflite_op); | ||||
| @@ -51,7 +52,8 @@ class TfliteCustomParser : public TfliteNodeParser { | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | const std::unique_ptr<tflite::OperatorT> &tflite_op); | ||||
| STATUS Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | STATUS Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, const std::unique_ptr<tflite::ModelT> &tflite_model); | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph); | |||||
| STATUS FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | STATUS FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | const std::unique_ptr<tflite::OperatorT> &tflite_op); | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -58,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| // get the conv op weight tensor | // get the conv op weight tensor | ||||
| auto weight_index = tflite_op->inputs[1]; | auto weight_index = tflite_op->inputs[1]; | ||||
| const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; | |||||
| const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | if (weight_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "the weight tensor is null"; | MS_LOG(ERROR) << "the weight tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -71,7 +72,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[2]; | auto data_index = tflite_op->inputs[2]; | ||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||||
| std::vector<int> params; | std::vector<int> params; | ||||
| int status = | int status = | ||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | ||||
| @@ -88,12 +89,9 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | op->primitive->value.type = schema::PrimitiveType_DeConv2D; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_KHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteDeConvParser : public TfliteNodeParser { | |||||
| TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {} | TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,7 +24,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; | MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| @@ -54,10 +55,8 @@ STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { | |||||
| TfliteDepthToSpaceParser() : TfliteNodeParser("DepthToSpace") {} | TfliteDepthToSpaceParser() : TfliteNodeParser("DepthToSpace") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,9 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -58,7 +60,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| // get the data tensor | // get the data tensor | ||||
| auto data_index = tflite_op->inputs[1]; | auto data_index = tflite_op->inputs[1]; | ||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||||
| if (data_tensor == nullptr) { | if (data_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "the data tensor is null"; | MS_LOG(ERROR) << "the data tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -68,7 +70,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| // get the weight tensor | // get the weight tensor | ||||
| auto weight_index = tflite_op->inputs[1]; | auto weight_index = tflite_op->inputs[1]; | ||||
| const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; | |||||
| const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | if (weight_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "the weight tensor is null"; | MS_LOG(ERROR) << "the weight tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -94,14 +96,10 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { | |||||
| TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} | TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; | MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -34,12 +35,12 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; | |||||
| const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "input tensor is null"; | MS_LOG(ERROR) << "input tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; | |||||
| const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "output tensor is null"; | MS_LOG(ERROR) << "output tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -68,10 +69,8 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_Cast; | op->primitive->value.type = schema::PrimitiveType_Cast; | ||||
| } | } | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,7 +29,8 @@ class TfliteDequantizeParser : public TfliteNodeParser { | |||||
| TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} | TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; | MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -41,17 +42,15 @@ STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<int> dims; | std::vector<int> dims; | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, dims)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, dims)) { | |||||
| MS_LOG(ERROR) << "get expand_dims -> dim failed"; | MS_LOG(ERROR) << "get expand_dims -> dim failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->dim = dims[0]; | attr->dim = dims[0]; | ||||
| op->primitive->value.type = schema::PrimitiveType_ExpandDims; | op->primitive->value.type = schema::PrimitiveType_ExpandDims; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | ||||
| @@ -30,7 +30,8 @@ class TfliteExpandDimsParser : public TfliteNodeParser { | |||||
| TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} | TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFillParser"; | MS_LOG(DEBUG) << "parse TfliteFillParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -41,7 +42,7 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| } | } | ||||
| if (tflite_op->inputs.size() > 1) { | if (tflite_op->inputs.size() > 1) { | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->dims)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) { | |||||
| MS_LOG(ERROR) << "get fill -> dims failed"; | MS_LOG(ERROR) << "get fill -> dims failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -50,10 +51,8 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| op->primitive->value.type = schema::PrimitiveType_Fill; | op->primitive->value.type = schema::PrimitiveType_Fill; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteFillParser : public TfliteNodeParser { | |||||
| TfliteFillParser() : TfliteNodeParser("Fill") {} | TfliteFillParser() : TfliteNodeParser("Fill") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,9 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -57,16 +59,12 @@ STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_FullConnection; | op->primitive->value.type = schema::PrimitiveType_FullConnection; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); | |||||
| if (hasBias) { | if (hasBias) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteFullyConnectedParser : public TfliteNodeParser { | |||||
| TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} | TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteFakeQuantParser : public TfliteFullyConnectedParser { | class TfliteFakeQuantParser : public TfliteFullyConnectedParser { | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; | MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -46,11 +47,9 @@ STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteGatherNdParser : public TfliteNodeParser { | |||||
| TfliteGatherNdParser() : TfliteNodeParser("GatherND") {} | TfliteGatherNdParser() : TfliteNodeParser("GatherND") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGatherParser"; | MS_LOG(DEBUG) << "parse TfliteGatherParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -52,11 +53,9 @@ STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteGatherParser : public TfliteNodeParser { | |||||
| TfliteGatherParser() : TfliteNodeParser("Gather") {} | TfliteGatherParser() : TfliteNodeParser("Gather") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,9 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser"; | MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -44,12 +46,10 @@ STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_HashtableLookup; | op->primitive->value.type = schema::PrimitiveType_HashtableLookup; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteHashtableLookupParser : public TfliteNodeParser { | |||||
| TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {} | TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteL2NormParser"; | MS_LOG(DEBUG) << "parse TfliteL2NormParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -49,10 +50,8 @@ STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| // set input and output | // set input and output | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteL2NormParser : public TfliteNodeParser { | |||||
| TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} | TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -67,11 +68,9 @@ STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteLogicalParser : public TfliteNodeParser { | |||||
| TfliteLogicalParser() : TfliteNodeParser("node_name") {} | TfliteLogicalParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteLogicalAndParser : public TfliteLogicalParser { | class TfliteLogicalAndParser : public TfliteLogicalParser { | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLRNParser"; | MS_LOG(DEBUG) << "parse TfliteLRNParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -53,10 +54,8 @@ STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique | |||||
| op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteLRNParser : public TfliteNodeParser { | |||||
| TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} | TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLshProjectionParser"; | MS_LOG(DEBUG) << "parse TfliteLshProjectionParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -56,11 +57,9 @@ STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteLshProjectionParser : public TfliteNodeParser { | |||||
| TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {} | TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -116,7 +116,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (status == RET_OK) { | if (status == RET_OK) { | ||||
| status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); | |||||
| status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get()); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| if (status == RET_NOT_FIND_OP) { | if (status == RET_NOT_FIND_OP) { | ||||
| op_type = | op_type = | ||||
| @@ -337,18 +337,10 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| // load graph | |||||
| auto tflite_model = ReadTfliteModel(model_file.c_str()); | |||||
| if (tflite_model == nullptr) { | |||||
| MS_LOG(ERROR) << "read tflite model failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | |||||
| if (tflite_model->subgraphs.size() != 1) { | |||||
| MS_LOG(ERROR) << "read tflite model subgraphs failed"; | |||||
| std::unique_ptr<schema::MetaGraphT> TfliteModelParser::ConstructMainGraph( | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, const QuantType &quant_type) { | |||||
| if (tflite_model->subgraphs.size() < 1) { | |||||
| MS_LOG(ERROR) << "read tflite model main subgraphs failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -394,7 +386,28 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return meta_graph.release(); | |||||
| return meta_graph; | |||||
| } | |||||
| schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| // load graph | |||||
| auto tflite_model = ReadTfliteModel(model_file.c_str()); | |||||
| if (tflite_model == nullptr) { | |||||
| MS_LOG(ERROR) << "read tflite model failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | |||||
| // construct main_meta_graph | |||||
| auto main_meta_graph = ConstructMainGraph(tflite_model, quant_type); | |||||
| if (main_meta_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "ConstructMainGraph failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | |||||
| return main_meta_graph.release(); | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -64,6 +64,9 @@ class TfliteModelParser : public ModelParser { | |||||
| STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph); | STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph); | ||||
| std::unique_ptr<schema::MetaGraphT> ConstructMainGraph(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const QuantType &quant_type); | |||||
| private: | private: | ||||
| TfliteTensorsInfo tensorsInfo; | TfliteTensorsInfo tensorsInfo; | ||||
| std::vector<schema::TensorT *> tensors; | std::vector<schema::TensorT *> tensors; | ||||
| @@ -39,7 +39,8 @@ class TfliteNodeParser { | |||||
| virtual ~TfliteNodeParser() = default; | virtual ~TfliteNodeParser() = default; | ||||
| virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) = 0; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) = 0; | |||||
| void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, schema::Format format) { | void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, schema::Format format) { | ||||
| int new_idx = tensors_info->tensorsId.size(); | int new_idx = tensors_info->tensorsId.size(); | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteOneHotParser"; | MS_LOG(DEBUG) << "parse TfliteOneHotParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -46,7 +47,7 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto axis = tflite_attr->axis; | auto axis = tflite_attr->axis; | ||||
| const auto &tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; | |||||
| const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||||
| if (tensor == nullptr) { | if (tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -57,11 +58,9 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteOneHotParser : public TfliteNodeParser { | |||||
| TfliteOneHotParser() : TfliteNodeParser("OneHot") {} | TfliteOneHotParser() : TfliteNodeParser("OneHot") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TflitePadParser"; | MS_LOG(DEBUG) << "parse TflitePadParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -51,8 +52,7 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique | |||||
| } | } | ||||
| attr->paddingMode = schema::PaddingMode_CONSTANT; | attr->paddingMode = schema::PaddingMode_CONSTANT; | ||||
| attr->constantValue = 0.0f; | attr->constantValue = 0.0f; | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, | |||||
| attr->paddings)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) { | |||||
| MS_LOG(ERROR) << "get pad -> paddings failed"; | MS_LOG(ERROR) << "get pad -> paddings failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -81,14 +81,11 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique | |||||
| op->primitive->value.type = schema::PrimitiveType_Pad; | op->primitive->value.type = schema::PrimitiveType_Pad; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| if (std::strcmp(node_name, "MirrorPad") == 0) { | if (std::strcmp(node_name, "MirrorPad") == 0) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TflitePadParser : public TfliteNodeParser { | |||||
| TflitePadParser() : TfliteNodeParser("Pad") {} | TflitePadParser() : TfliteNodeParser("Pad") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -69,7 +70,7 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[0]; | auto data_index = tflite_op->inputs[0]; | ||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||||
| std::vector<int> params; | std::vector<int> params; | ||||
| int status = | int status = | ||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); | getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); | ||||
| @@ -86,10 +87,8 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| op->primitive->value.type = schema::PrimitiveType_Pooling; | op->primitive->value.type = schema::PrimitiveType_Pooling; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TflitePoolingParser : public TfliteNodeParser { | |||||
| TflitePoolingParser() : TfliteNodeParser("node_name") {} | TflitePoolingParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteMeanPoolingParser : public TflitePoolingParser { | class TfliteMeanPoolingParser : public TflitePoolingParser { | ||||
| @@ -23,7 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TflitePReLUParser"; | MS_LOG(DEBUG) << "parse TflitePReLUParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -44,12 +45,9 @@ STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| op->primitive->value.type = schema::PrimitiveType_PReLU; | op->primitive->value.type = schema::PrimitiveType_PReLU; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TflitePReLUParser : public TfliteNodeParser { | |||||
| TflitePReLUParser() : TfliteNodeParser("PRELU") {} | TflitePReLUParser() : TfliteNodeParser("PRELU") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,7 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteQuantizeNParser"; | MS_LOG(DEBUG) << "parse TfliteQuantizeNParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -33,12 +34,12 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; | |||||
| const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "input tensor is null"; | MS_LOG(ERROR) << "input tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; | |||||
| const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "output tensor is null"; | MS_LOG(ERROR) << "output tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -67,10 +68,8 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } | } | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,7 +29,8 @@ class TfliteQuantizeParser : public TfliteNodeParser { | |||||
| TfliteQuantizeParser() : TfliteNodeParser("Quantize") {} | TfliteQuantizeParser() : TfliteNodeParser("Quantize") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRangeParser"; | MS_LOG(DEBUG) << "parse TfliteRangeParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -43,12 +44,12 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| attr->dType = 0; | attr->dType = 0; | ||||
| std::vector<int> limit; | std::vector<int> limit; | ||||
| std::vector<int> delta; | std::vector<int> delta; | ||||
| int status = GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, limit); | |||||
| int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, limit); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| MS_LOG(ERROR) << "range -> limit get failed"; | MS_LOG(ERROR) << "range -> limit get failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } else if (status == RET_OK) { | } else if (status == RET_OK) { | ||||
| status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, delta); | |||||
| status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, delta); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| MS_LOG(ERROR) << "stridedSlice -> end get failed"; | MS_LOG(ERROR) << "stridedSlice -> end get failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -63,11 +64,9 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| int input_num = status == RET_OK ? 1 : 3; | int input_num = status == RET_OK ? 1 : 3; | ||||
| for (int i = 0; i < input_num; ++i) { | for (int i = 0; i < input_num; ++i) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteRangeParser : public TfliteNodeParser { | |||||
| TfliteRangeParser() : TfliteNodeParser("Range") {} | TfliteRangeParser() : TfliteNodeParser("Range") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRankParser"; | MS_LOG(DEBUG) << "parse TfliteRankParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -43,10 +44,8 @@ STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| op->primitive->value.type = schema::PrimitiveType_Rank; | op->primitive->value.type = schema::PrimitiveType_Rank; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteRankParser : public TfliteNodeParser { | |||||
| TfliteRankParser() : TfliteNodeParser("Rank") {} | TfliteRankParser() : TfliteNodeParser("Rank") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -72,7 +73,7 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| return RET_NOT_SUPPORT; | return RET_NOT_SUPPORT; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axes)) { | |||||
| MS_LOG(ERROR) << "get reduce -> axes failed"; | MS_LOG(ERROR) << "get reduce -> axes failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -80,10 +81,8 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| op->primitive->value.type = schema::PrimitiveType_Reduce; | op->primitive->value.type = schema::PrimitiveType_Reduce; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteReduceParser : public TfliteNodeParser { | |||||
| TfliteReduceParser() : TfliteNodeParser("node_name") {} | TfliteReduceParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteReduceMaxParser : public TfliteReduceParser { | class TfliteReduceMaxParser : public TfliteReduceParser { | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReshapeParser"; | MS_LOG(DEBUG) << "parse TfliteReshapeParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -47,7 +48,7 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto shape_tensor_index = tflite_op->inputs[1]; | auto shape_tensor_index = tflite_op->inputs[1]; | ||||
| const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[shape_tensor_index]; | |||||
| const auto &shape_tensor = tflite_subgraph->tensors[shape_tensor_index]; | |||||
| if (shape_tensor == nullptr) { | if (shape_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "shape_tensor is null"; | MS_LOG(ERROR) << "shape_tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -58,8 +59,7 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (!buf_data->data.empty()) { | if (!buf_data->data.empty()) { | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, | |||||
| attr->shape)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->shape)) { | |||||
| MS_LOG(ERROR) << "get reshape -> shape failed"; | MS_LOG(ERROR) << "get reshape -> shape failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -76,11 +76,9 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteReshapeParser : public TfliteNodeParser { | |||||
| TfliteReshapeParser() : TfliteNodeParser("Reshape") {} | TfliteReshapeParser() : TfliteNodeParser("Reshape") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -87,7 +88,7 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| attr->preserveAspectRatio = false; | attr->preserveAspectRatio = false; | ||||
| auto tfliteResizeTensorIndex = tflite_op->inputs[1]; | auto tfliteResizeTensorIndex = tflite_op->inputs[1]; | ||||
| const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[tfliteResizeTensorIndex]; | |||||
| const auto &shape_tensor = tflite_subgraph->tensors[tfliteResizeTensorIndex]; | |||||
| if (shape_tensor == nullptr) { | if (shape_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "shape_tensor is null"; | MS_LOG(ERROR) << "shape_tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -109,14 +110,11 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| op->primitive->value.type = schema::PrimitiveType_Resize; | op->primitive->value.type = schema::PrimitiveType_Resize; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| if (buffData == nullptr) { | if (buffData == nullptr) { | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteResizeParser : public TfliteNodeParser { | |||||
| TfliteResizeParser() : TfliteNodeParser("node_name") {} | TfliteResizeParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteResizeBilinearParser : public TfliteResizeParser { | class TfliteResizeBilinearParser : public TfliteResizeParser { | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReverseParser"; | MS_LOG(DEBUG) << "parse TfliteReverseParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -40,7 +41,7 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axis)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axis)) { | |||||
| MS_LOG(ERROR) << "get reverse -> axis failed"; | MS_LOG(ERROR) << "get reverse -> axis failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -48,10 +49,8 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| op->primitive->value.type = schema::PrimitiveType_Reverse; | op->primitive->value.type = schema::PrimitiveType_Reverse; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteReverseParser : public TfliteNodeParser { | |||||
| TfliteReverseParser() : TfliteNodeParser("reverse") {} | TfliteReverseParser() : TfliteNodeParser("reverse") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,7 +24,9 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; | MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -53,12 +55,9 @@ STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_ReverseSequence; | op->primitive->value.type = schema::PrimitiveType_ReverseSequence; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteReverseSequenceParser : public TfliteNodeParser { | |||||
| TfliteReverseSequenceParser() : TfliteNodeParser("ReverseSequence") {} | TfliteReverseSequenceParser() : TfliteNodeParser("ReverseSequence") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,7 +24,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteScatterNdParser"; | MS_LOG(DEBUG) << "parse TfliteScatterNdParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -52,14 +53,10 @@ STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 | // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 | ||||
| // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; | // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteScatterNdParser : public TfliteNodeParser { | |||||
| TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} | TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteShapeParser"; | MS_LOG(DEBUG) << "parse TfliteShapeParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -43,10 +44,8 @@ STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| op->primitive->value.type = schema::PrimitiveType_Shape; | op->primitive->value.type = schema::PrimitiveType_Shape; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteShapeParser : public TfliteNodeParser { | |||||
| TfliteShapeParser() : TfliteNodeParser("Shape") {} | TfliteShapeParser() : TfliteNodeParser("Shape") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSkipGramParser"; | MS_LOG(DEBUG) << "parse TfliteSkipGramParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -52,10 +53,8 @@ STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||||
| op->primitive->value.type = schema::PrimitiveType_SkipGram; | op->primitive->value.type = schema::PrimitiveType_SkipGram; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteSkipGramParser : public TfliteNodeParser { | |||||
| TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {} | TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSliceParser"; | MS_LOG(DEBUG) << "parse TfliteSliceParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -42,11 +43,11 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| attr->format = schema::Format::Format_NHWC; | attr->format = schema::Format::Format_NHWC; | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin)) { | |||||
| MS_LOG(ERROR) << "get slice -> begin failed"; | MS_LOG(ERROR) << "get slice -> begin failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->size)) { | |||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->size)) { | |||||
| MS_LOG(ERROR) << "get slice -> size failed"; | MS_LOG(ERROR) << "get slice -> size failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -59,10 +60,8 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| op->primitive->value.type = schema::PrimitiveType_Slice; | op->primitive->value.type = schema::PrimitiveType_Slice; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteSliceParser : public TfliteNodeParser { | |||||
| TfliteSliceParser() : TfliteNodeParser("Slice") {} | TfliteSliceParser() : TfliteNodeParser("Slice") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; | MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -45,10 +46,8 @@ STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| op->primitive->value.type = schema::PrimitiveType_SoftMax; | op->primitive->value.type = schema::PrimitiveType_SoftMax; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteSoftmaxParser : public TfliteNodeParser { | |||||
| TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {} | TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,7 +24,9 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSpaceToBatchNDParser"; | MS_LOG(DEBUG) << "parse TfliteSpaceToBatchNDParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -42,12 +44,11 @@ STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, | |||||
| attr->blockShape)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) { | |||||
| MS_LOG(ERROR) << "get spaceToBatchND -> blockShape failed"; | MS_LOG(ERROR) << "get spaceToBatchND -> blockShape failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->paddings)) { | |||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) { | |||||
| MS_LOG(ERROR) << "get spaceToBatchND -> paddings failed"; | MS_LOG(ERROR) << "get spaceToBatchND -> paddings failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -55,10 +56,8 @@ STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_SpaceToBatchND; | op->primitive->value.type = schema::PrimitiveType_SpaceToBatchND; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteSpaceToBatchNDParser : public TfliteNodeParser { | |||||
| TfliteSpaceToBatchNDParser() : TfliteNodeParser("SpaceToBatchND") {} | TfliteSpaceToBatchNDParser() : TfliteNodeParser("SpaceToBatchND") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,7 +24,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; | MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -53,10 +54,8 @@ STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; | op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteSpaceToDepthParser : public TfliteNodeParser { | |||||
| TfliteSpaceToDepthParser() : TfliteNodeParser("SpaceToDepth") {} | TfliteSpaceToDepthParser() : TfliteNodeParser("SpaceToDepth") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,7 +24,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; | MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -46,16 +47,11 @@ STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_SparseToDense; | op->primitive->value.type = schema::PrimitiveType_SparseToDense; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[3], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -30,7 +30,8 @@ class TfliteSparseToDenseParser : public TfliteNodeParser { | |||||
| TfliteSparseToDenseParser() : TfliteNodeParser("SparseToDense") {} | TfliteSparseToDenseParser() : TfliteNodeParser("SparseToDense") {} | ||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSplitParser"; | MS_LOG(DEBUG) << "parse TfliteSplitParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -47,13 +48,13 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| } | } | ||||
| auto num_splits = tflite_attr->num_splits; | auto num_splits = tflite_attr->num_splits; | ||||
| const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[1]]; | |||||
| const auto &shape_tensor = tflite_subgraph->tensors[tflite_op->inputs[1]]; | |||||
| if (shape_tensor == nullptr) { | if (shape_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "shape_tensor is null"; | MS_LOG(ERROR) << "shape_tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto tensor_shape = shape_tensor->shape; | const auto tensor_shape = shape_tensor->shape; | ||||
| const auto &axis_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; | |||||
| const auto &axis_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||||
| if (axis_tensor == nullptr) { | if (axis_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "axis_tensor is null"; | MS_LOG(ERROR) << "axis_tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -81,11 +82,9 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| op->primitive->value.type = schema::PrimitiveType_Split; | op->primitive->value.type = schema::PrimitiveType_Split; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| for (size_t i = 0; i < tflite_op->outputs.size(); i++) { | for (size_t i = 0; i < tflite_op->outputs.size(); i++) { | ||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||