Browse Source

!9773 [lite] onnx replace transpose with graph input

From: @xu_anyue
Reviewed-by: @hangangqiang,@zhang_xue_tong
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
820ba956a5
2 changed files with 54 additions and 1 deletions
  1. +53
    -1
      mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
  2. +1
    -0
      mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h

+ 53
- 1
mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc View File

@@ -391,7 +391,7 @@ STATUS OnnxInputAdjustOpPass::AdjustCast(const CNodePtr &cnode) {
MS_LOG(ERROR) << "cnode input0 is not a valuenode.";
return lite::RET_ERROR;
}
MS_ASSERT(value_node->value != nullptr);
MS_ASSERT(value_node->value() != nullptr);
auto primitive_c = value_node->value()->cast<PrimitiveCPtr>();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "cnode has no primitive_c.";
@@ -461,6 +461,56 @@ STATUS OnnxInputAdjustOpPass::ReplaceConstant(const FuncGraphPtr &func_graph, co
return lite::RET_OK;
}

STATUS OnnxInputAdjustOpPass::ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(cnode != nullptr);
if (cnode->inputs().size() != 2) {
MS_LOG(ERROR) << "onnx transpose input size is 1, now is " << cnode->inputs().size() - 1;
return lite::RET_ERROR;
}
auto anf_node = cnode->input(1);
MS_ASSERT(anf_node != nullptr);
auto param_node = anf_node->cast<ParameterPtr>();
if (param_node == nullptr || !param_node->has_default()) {
MS_LOG(DEBUG) << "input is not graph input";
return lite::RET_OK;
}
MS_ASSERT(param_node->abstract() != nullptr && param_node->abstract()->GetShapeTrack() != nullptr);
auto shape_ptr = param_node->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>();
if (shape_ptr == nullptr) {
MS_LOG(ERROR) << "shape is nullptr.";
}
auto shape_vector = shape_ptr->shape();
if (shape_vector.size() != 4) {
MS_LOG(DEBUG) << "only adjust 4 dims graph input.";
return lite::RET_OK;
}
auto prim_anf = cnode->input(0);
if (prim_anf == nullptr || !utils::isa<ValueNodePtr>(prim_anf)) {
MS_LOG(ERROR) << "cnode input0 is invalid.";
return lite::RET_ERROR;
}
auto value_node = prim_anf->cast<ValueNodePtr>();
MS_ASSERT(value_node->value() != nullptr);
auto prim = value_node->value()->cast<PrimitiveCPtr>();
MS_ASSERT(prim != nullptr && prim->primitiveT() != nullptr && prim->primitiveT()->value.value != nullptr);
auto attr = reinterpret_cast<schema::TransposeT *>(prim->primitiveT()->value.value);
auto perm = attr->perm;
std::vector<int> transpose_attr;
std::transform(perm.begin(), perm.end(), std::back_inserter(transpose_attr),
[](const int &val) { return val < 0 ? val + 4 : val; });
if (transpose_attr[0] == 0 && transpose_attr[1] == 3 && transpose_attr[2] == 1) {
auto channel = shape_vector[3];
shape_vector.pop_back();
shape_vector.insert(shape_vector.begin() + 1, channel);
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vector));
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
manager->Replace(cnode, param_node);
}
return lite::RET_OK;
}

bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = Manage(func_graph, true);
@@ -497,6 +547,8 @@ bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) {
status = ReplaceConstant(func_graph, cnode);
} else if (type == schema::PrimitiveType_Cast) {
status = AdjustCast(cnode);
} else if (type == schema::PrimitiveType_Transpose) {
status = ReplaceTransposeWithGraphInput(func_graph, cnode);
}
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "adjust input pass is failed.";


+ 1
- 0
mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h View File

@@ -41,6 +41,7 @@ class OnnxInputAdjustOpPass : public Pass {
STATUS AdjustTile(const CNodePtr &cnode);
STATUS AdjustCast(const CNodePtr &cnode);
STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
bool Run(const FuncGraphPtr &func_graph) override;

private:


Loading…
Cancel
Save