Browse Source

modified: tests/ut/ge/CMakeLists.txt

new file:   tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc
pull/931/head
zhaoxinxin 5 years ago
parent
commit
78d1dc4e9c
2 changed files with 79 additions and 1 deletions
  1. +2
    -1
      tests/ut/ge/CMakeLists.txt
  2. +77
    -0
      tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc

+ 2
- 1
tests/ut/ge/CMakeLists.txt View File

@@ -36,7 +36,7 @@ set(PROTO_LIST
"${GE_CODE_DIR}/metadef/proto/proto_inner/ge_onnx.proto"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate((ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})

# include directories
include_directories(${CMAKE_CURRENT_LIST_DIR})
@@ -694,6 +694,7 @@ set(MULTI_PARTS_TEST_FILES
"graph/variable_accelerate_ctrl_unittest.cc"
"graph/build/logical_stream_allocator_unittest.cc"
"graph/build/mem_assigner_unittest.cc"
"graph/preprocess/graph_preprocess_unittest.cc"
"session/omg_omg_unittest.cc"
)



+ 77
- 0
tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc View File

@@ -0,0 +1,77 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <memory>

#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "common/util.h"
#include "passes/graph_builder_utils.h"

#define private public
#define protected public
#include "graph/preprocess/graph_preprocess.h"
#include "ge/ge_api.h"
#undef private
#undef protected

using namespace std;
namespace ge {
class UtestGraphPreproces : public testing::Test {
protected:
void SetUp() {
map<string, string> options;
ge::Status ret = ge::GEInitialize();
EXPECT_EQ(ret, ge::SUCCESS);
}
void TearDown() {
ge::Status ret = ge::GEFinalize();
EXPECT_EQ(ret, ge::SUCCESS);
}
};

ComputeGraphPtr BuildGraph1(){
auto builder = ut::GraphBuilder("g1");
auto data1 = builder.AddNode("data1",DATA,0,1);
return builder.GetGraph();
}

TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) {
ge::GraphPrepare graph_prepare;
graph_prepare.compute_graph_ = BuildGraph1();
// prepare user_input & graph option
ge::GeTensorDesc tensor1;
tensor.SetFormat(ge::FORMAT_NCHW);
tensor.SetShape(ge::GeShape({3, 12, 5, 5}));
tensor.SetDataType(ge::DT_FLOAT);
GeTensor input1 = std::make_shared<GeTensor>(tensor1);
std::vector<GeTensor> user_input = {input1};
std::map<string,string> graph_option = {{"ge.exec.dynamicGraphExecuteMode","dynamic_execute"},
{"ge.exec.dataInputsShapeRange","[3,1~20,2~10,5]"}};
auto ret = graph_prepare.UpdateInput(input1, graph_option);
EXPECT_EQ(ret, ge::SUCCESS);
// check data node output shape_range and shape
auto data_node = graph_prepare.compute_graph_->FindNode("data1");
auto data_output_desc = data_node->GetOpDesc()->GetOutputDescPtr(0);
vector<int64_t> expect_shape = {3,-1,-1,5};
auto result_shape = data_output_desc->GetShape();
EXPECT_EQ(result_shape.GetDimNum(), expect_shape.size());
for(size_t i =0; i< expect_shape.size(); ++i){
EXPECT_EQ(result_shape.GetDim(i), expect_shape.at(i));
}
}
}

Loading…
Cancel
Save