From 0facfed4ff4179d4da2909efbc2de8a02c88c80a Mon Sep 17 00:00:00 2001 From: xuanyue Date: Thu, 10 Dec 2020 15:01:50 +0800 Subject: [PATCH] onnx replace transpose with graph input --- .../graph/onnx_inputs_adjust_pass.cc | 54 ++++++++++++++++++- .../optimizer/graph/onnx_inputs_adjust_pass.h | 1 + 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc index 0f27c33d53..70a5707736 100644 --- a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc @@ -391,7 +391,7 @@ STATUS OnnxInputAdjustOpPass::AdjustCast(const CNodePtr &cnode) { MS_LOG(ERROR) << "cnode input0 is not a valuenode."; return lite::RET_ERROR; } - MS_ASSERT(value_node->value != nullptr); + MS_ASSERT(value_node->value() != nullptr); auto primitive_c = value_node->value()->cast(); if (primitive_c == nullptr) { MS_LOG(ERROR) << "cnode has no primitive_c."; @@ -461,6 +461,56 @@ STATUS OnnxInputAdjustOpPass::ReplaceConstant(const FuncGraphPtr &func_graph, co return lite::RET_OK; } +STATUS OnnxInputAdjustOpPass::ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(cnode != nullptr); + if (cnode->inputs().size() != 2) { + MS_LOG(ERROR) << "onnx transpose input size is 1, now is " << cnode->inputs().size() - 1; + return lite::RET_ERROR; + } + auto anf_node = cnode->input(1); + MS_ASSERT(anf_node != nullptr); + auto param_node = anf_node->cast(); + if (param_node == nullptr || !param_node->has_default()) { + MS_LOG(DEBUG) << "input is not graph input"; + return lite::RET_OK; + } + MS_ASSERT(param_node->abstract() != nullptr && param_node->abstract()->GetShapeTrack() != nullptr); + auto shape_ptr = param_node->abstract()->GetShapeTrack()->cast(); + if (shape_ptr == nullptr) { + MS_LOG(ERROR) << "shape is nullptr."; + } + auto shape_vector = shape_ptr->shape(); + if (shape_vector.size() != 4) { + MS_LOG(DEBUG) << "only adjust 4 dims graph input."; + return lite::RET_OK; + } + auto prim_anf = cnode->input(0); + if (prim_anf == nullptr || !utils::isa(prim_anf)) { + MS_LOG(ERROR) << "cnode input0 is invalid."; + return lite::RET_ERROR; + } + auto value_node = prim_anf->cast(); + MS_ASSERT(value_node->value() != nullptr); + auto prim = value_node->value()->cast(); + MS_ASSERT(prim != nullptr && prim->primitiveT() != nullptr && prim->primitiveT()->value.value != nullptr); + auto attr = reinterpret_cast(prim->primitiveT()->value.value); + auto perm = attr->perm; + std::vector transpose_attr; + std::transform(perm.begin(), perm.end(), std::back_inserter(transpose_attr), + [](const int &val) { return val < 0 ? val + 4 : val; }); + if (transpose_attr[0] == 0 && transpose_attr[1] == 3 && transpose_attr[2] == 1) { + auto channel = shape_vector[3]; + shape_vector.pop_back(); + shape_vector.insert(shape_vector.begin() + 1, channel); + param_node->abstract()->set_shape(std::make_shared(shape_vector)); + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + manager->Replace(cnode, param_node); + } + return lite::RET_OK; +} + bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) { MS_ASSERT(func_graph != nullptr); auto manager = Manage(func_graph, true); @@ -497,6 +547,8 @@ bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) { status = ReplaceConstant(func_graph, cnode); } else if (type == schema::PrimitiveType_Cast) { status = AdjustCast(cnode); + } else if (type == schema::PrimitiveType_Transpose) { + status = ReplaceTransposeWithGraphInput(func_graph, cnode); } if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { MS_LOG(ERROR) << "adjust input pass is failed."; diff --git a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h index 66ebd6a2dc..b8b11097e0 100644 --- a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h +++ b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h @@ -41,6 +41,7 @@ class OnnxInputAdjustOpPass : public Pass { STATUS AdjustTile(const CNodePtr &cnode); STATUS AdjustCast(const CNodePtr &cnode); STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); bool Run(const FuncGraphPtr &func_graph) override; private: