|
|
|
@@ -391,7 +391,7 @@ STATUS OnnxInputAdjustOpPass::AdjustCast(const CNodePtr &cnode) { |
|
|
|
MS_LOG(ERROR) << "cnode input0 is not a valuenode."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
MS_ASSERT(value_node->value != nullptr); |
|
|
|
MS_ASSERT(value_node->value() != nullptr); |
|
|
|
auto primitive_c = value_node->value()->cast<PrimitiveCPtr>(); |
|
|
|
if (primitive_c == nullptr) { |
|
|
|
MS_LOG(ERROR) << "cnode has no primitive_c."; |
|
|
|
@@ -461,6 +461,56 @@ STATUS OnnxInputAdjustOpPass::ReplaceConstant(const FuncGraphPtr &func_graph, co |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS OnnxInputAdjustOpPass::ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { |
|
|
|
MS_ASSERT(func_graph != nullptr); |
|
|
|
MS_ASSERT(cnode != nullptr); |
|
|
|
if (cnode->inputs().size() != 2) { |
|
|
|
MS_LOG(ERROR) << "onnx transpose input size is 1, now is " << cnode->inputs().size() - 1; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
auto anf_node = cnode->input(1); |
|
|
|
MS_ASSERT(anf_node != nullptr); |
|
|
|
auto param_node = anf_node->cast<ParameterPtr>(); |
|
|
|
if (param_node == nullptr || !param_node->has_default()) { |
|
|
|
MS_LOG(DEBUG) << "input is not graph input"; |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
MS_ASSERT(param_node->abstract() != nullptr && param_node->abstract()->GetShapeTrack() != nullptr); |
|
|
|
auto shape_ptr = param_node->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>(); |
|
|
|
if (shape_ptr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "shape is nullptr."; |
|
|
|
} |
|
|
|
auto shape_vector = shape_ptr->shape(); |
|
|
|
if (shape_vector.size() != 4) { |
|
|
|
MS_LOG(DEBUG) << "only adjust 4 dims graph input."; |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
auto prim_anf = cnode->input(0); |
|
|
|
if (prim_anf == nullptr || !utils::isa<ValueNodePtr>(prim_anf)) { |
|
|
|
MS_LOG(ERROR) << "cnode input0 is invalid."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
auto value_node = prim_anf->cast<ValueNodePtr>(); |
|
|
|
MS_ASSERT(value_node->value() != nullptr); |
|
|
|
auto prim = value_node->value()->cast<PrimitiveCPtr>(); |
|
|
|
MS_ASSERT(prim != nullptr && prim->primitiveT() != nullptr && prim->primitiveT()->value.value != nullptr); |
|
|
|
auto attr = reinterpret_cast<schema::TransposeT *>(prim->primitiveT()->value.value); |
|
|
|
auto perm = attr->perm; |
|
|
|
std::vector<int> transpose_attr; |
|
|
|
std::transform(perm.begin(), perm.end(), std::back_inserter(transpose_attr), |
|
|
|
[](const int &val) { return val < 0 ? val + 4 : val; }); |
|
|
|
if (transpose_attr[0] == 0 && transpose_attr[1] == 3 && transpose_attr[2] == 1) { |
|
|
|
auto channel = shape_vector[3]; |
|
|
|
shape_vector.pop_back(); |
|
|
|
shape_vector.insert(shape_vector.begin() + 1, channel); |
|
|
|
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vector)); |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_ASSERT(manager != nullptr); |
|
|
|
manager->Replace(cnode, param_node); |
|
|
|
} |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) { |
|
|
|
MS_ASSERT(func_graph != nullptr); |
|
|
|
auto manager = Manage(func_graph, true); |
|
|
|
@@ -497,6 +547,8 @@ bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) { |
|
|
|
status = ReplaceConstant(func_graph, cnode); |
|
|
|
} else if (type == schema::PrimitiveType_Cast) { |
|
|
|
status = AdjustCast(cnode); |
|
|
|
} else if (type == schema::PrimitiveType_Transpose) { |
|
|
|
status = ReplaceTransposeWithGraphInput(func_graph, cnode); |
|
|
|
} |
|
|
|
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { |
|
|
|
MS_LOG(ERROR) << "adjust input pass is failed."; |
|
|
|
|