| @@ -64,11 +64,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| // for now - trainning is not supporting fuse operations | |||
| if (config != nullptr && !config->trainModel) { | |||
| // remove quantdtype when awaretraining | |||
| if (config->fmk == lite::converter::FmkType_ONNX) { | |||
| auto remove_identity_pass = std::make_shared<opt::RemoveIdentityOpPass>(); | |||
| remove_identity_pass->SetFmkType(config->fmk); | |||
| pm->AddPass(remove_identity_pass); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | |||
| pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ConvScaleFusion>()); | |||
| @@ -181,6 +181,33 @@ STATUS TfliteCustomParser::FftImag(const std::vector<uint8_t> &custom_attr, sche | |||
| return RET_OK; | |||
| } | |||
| STATUS TfliteCustomParser::Identity(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op) { | |||
| std::unique_ptr<schema::IdentityT> attr = std::make_unique<schema::IdentityT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Identity; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| } | |||
| STATUS TfliteCustomParser::BatchMatMul(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op) { | |||
| std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->broadcast = false; | |||
| attr->transposeA = false; | |||
| attr->transposeB = false; | |||
| op->primitive->value.type = schema::PrimitiveType_MatMul; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| } | |||
| STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| @@ -216,6 +243,10 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||
| status = FftReal(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "FlexImag") { | |||
| status = FftImag(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "FlexIdentityN" || custom_type == "FlexIdentity") { | |||
| status = Identity(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "FlexBatchMatMul") { | |||
| status = BatchMatMul(custom_attr, op, tflite_op); | |||
| } else { | |||
| MS_LOG(ERROR) << "the custom op hasn't been supported now"; | |||
| status = RET_NOT_FIND_OP; | |||
| @@ -60,6 +60,12 @@ class TfliteCustomParser : public TfliteNodeParser { | |||
| STATUS FftImag(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | |||
| STATUS Identity(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | |||
| STATUS BatchMatMul(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tflite/tflite_matmul_parser.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteMatMulParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteMatMulParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions(); | |||
| attr->transposeA = tflite_attr->adj_x; | |||
| attr->transposeB = tflite_attr->adj_y; | |||
| attr->broadcast = false; | |||
| op->primitive->value.type = schema::PrimitiveType_MatMul; | |||
| op->primitive->value.value = attr.release(); | |||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_tfliteMatMulParser("MatMul", new TfliteMatMulParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MATMUL_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MATMUL_PARSER_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TfliteMatMulParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteMatMulParser() : TfliteNodeParser("MatMul") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SLICE_PARSER_H | |||
| @@ -109,10 +109,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||
| continue; | |||
| } | |||
| if (status == RET_OK) { | |||
| status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get()); | |||
| if (status != RET_OK) { | |||
| if (status == RET_NOT_FIND_OP) { | |||
| if (status == RET_OK || op_type == "Custom") { | |||
| int status_node = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get()); | |||
| status = (status == RET_OK ? status_node : status); | |||
| if (status_node != RET_OK) { | |||
| if (status_node == RET_NOT_FIND_OP) { | |||
| op_type = | |||
| (op_type != "Custom" ? op_type : (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code); | |||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||
| @@ -121,6 +122,9 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||
| } | |||
| continue; | |||
| } | |||
| if (status != RET_OK) { | |||
| continue; | |||
| } | |||
| sub_graph->nodes.emplace_back(op.release()); | |||
| opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); | |||
| tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); | |||
| @@ -14,34 +14,96 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/identity_remove_pass.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore::opt { | |||
| bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| if (this->fmk_type != lite::converter::FmkType_ONNX) { | |||
| MS_LOG(INFO) << "The framework type of model should be onnx."; | |||
| return RET_OK; | |||
| int RemoveIdentityOpPass::ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||
| if (!utils::isa<CNodePtr>(anf_node)) { | |||
| MS_LOG(DEBUG) << "anf node is node a cnode."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto type = opt::GetCNodeType(anf_node); | |||
| if (type != schema::PrimitiveType_Identity) { | |||
| MS_LOG(DEBUG) << "anf node is not a identity node."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto identity_cnode = anf_node->cast<CNodePtr>(); | |||
| if (identity_cnode->inputs().size() != lite::kDoubleNum) { | |||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||
| remove_cnode_.insert(anf_node); | |||
| return lite::RET_NO_CHANGE; | |||
| } else { | |||
| bool replace_succ = manager->Replace(anf_node, identity_cnode->input(1)); | |||
| if (!replace_succ) { | |||
| MS_LOG(ERROR) << "replace identity failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||
| if (!utils::isa<CNodePtr>(anf_node)) { | |||
| MS_LOG(DEBUG) << "anf node is node a cnode."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto type = opt::GetCNodeType(anf_node); | |||
| if (type != schema::PrimitiveType_TupleGetItem) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (cnode->inputs().size() != 3) { | |||
| MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size(); | |||
| return RET_ERROR; | |||
| } | |||
| type = opt::GetCNodeType(cnode->input(1)); | |||
| if (type != schema::PrimitiveType_Identity) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto get_item_input_cnode = cnode->input(1)->cast<CNodePtr>(); | |||
| auto index_vnode = cnode->input(2); | |||
| if (!utils::isa<ValueNode>(index_vnode)) { | |||
| MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| int index = lite::CastToInt(index_vnode->cast<ValueNodePtr>()->value(), false).front(); | |||
| int input_cnode_inputs_size = get_item_input_cnode->inputs().size(); | |||
| if ((index + 1) >= input_cnode_inputs_size) { | |||
| MS_LOG(ERROR) << "value node index is out of range."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| bool replace_succ = manager->Replace(anf_node, get_item_input_cnode->input(index + 1)); | |||
| if (!replace_succ) { | |||
| MS_LOG(ERROR) << "replace identity failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| int status = RET_OK; | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type != schema::PrimitiveType_Identity) { | |||
| continue; | |||
| if (type == schema::PrimitiveType_Identity) { | |||
| status = ReplaceIdentity(node, manager); | |||
| } else if (type == schema::PrimitiveType_TupleGetItem) { | |||
| status = ReplaceTupleGetItem(node, manager); | |||
| } | |||
| auto identity_cnode = node->cast<CNodePtr>(); | |||
| if (identity_cnode->inputs().size() != lite::kDoubleNum) { | |||
| MS_LOG(ERROR) << "The `node input is a single tensor"; | |||
| return RET_ERROR; | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "remove identity pass is failed."; | |||
| return false; | |||
| } | |||
| manager->Replace(node, identity_cnode->input(1)); | |||
| } | |||
| for (auto &node : remove_cnode_) { | |||
| func_graph->DropNode(node); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -17,8 +17,10 @@ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||
| #include <string> | |||
| #include <set> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore::opt { | |||
| @@ -26,11 +28,12 @@ class RemoveIdentityOpPass : public Pass { | |||
| public: | |||
| RemoveIdentityOpPass() : Pass("remove_identity_pass") {} | |||
| ~RemoveIdentityOpPass() override = default; | |||
| void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; } | |||
| int ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||
| int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| FmkType fmk_type = lite::converter::FmkType_ONNX; | |||
| std::set<AnfNodePtr> remove_cnode_; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||