From 196fe87ffebd16fcf9664d7cc6b7ed0cd244a8be Mon Sep 17 00:00:00 2001 From: zhengyuanhua Date: Thu, 11 Mar 2021 16:25:19 +0800 Subject: [PATCH] modify input shape check --- ge/graph/preprocess/graph_preprocess.cc | 11 +++++------ .../graph/preprocess/graph_preprocess_unittest.cc | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index db17e091..84e57a9b 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -1763,13 +1763,12 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { GeTensorDesc desc(user_input[index].GetTensorDesc()); for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { - if (desc.GetShape().GetDim(i) < 0) { - std::string situation = "data dim[" + std::to_string(i) + "][" + - std::to_string(desc.GetShape().GetDim(i)) + "]" ; - std::string reason = "it need >= 0"; + int64_t dim = desc.GetShape().GetDim(i); + if (dim < 0 && dim != UNKNOWN_DIM) { + std::string situation = "data dim[" + std::to_string(i) + "][" + std::to_string(dim) + "]" ; + std::string reason = "it need >= -1"; ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason}); - GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i, - desc.GetShape().GetDim(i)); + GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= -1, real:%ld.", i, dim); return GE_GRAPH_INIT_FAILED; } } diff --git a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc index 2f149761..69192631 100644 --- a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc +++ b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc @@ -74,4 +74,18 @@ TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) { EXPECT_EQ(result_shape.GetDim(i), expect_shape.at(i)); } } + +TEST_F(UtestGraphPreproces, test_check_user_input) { + ge::GraphPrepare graph_prepare; + graph_prepare.compute_graph_ = BuildGraph1(); + + vector dim = {2, -3}; + GeTensor tensor; + tensor.SetTensorDesc(GeTensorDesc(GeShape(dim))); + std::vector user_input; + user_input.emplace_back(tensor); + + Status ret = graph_prepare.CheckUserInput(user_input); + EXPECT_EQ(ret, GE_GRAPH_INIT_FAILED); +} } \ No newline at end of file