|
|
|
@@ -0,0 +1,77 @@ |
|
|
|
/** |
|
|
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* You may obtain a copy of the License at |
|
|
|
* |
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <gtest/gtest.h> |
|
|
|
#include <memory> |
|
|
|
|
|
|
|
#include "common/ge_inner_error_codes.h" |
|
|
|
#include "common/types.h" |
|
|
|
#include "common/util.h" |
|
|
|
#include "passes/graph_builder_utils.h" |
|
|
|
|
|
|
|
#define private public |
|
|
|
#define protected public |
|
|
|
#include "graph/preprocess/graph_preprocess.h" |
|
|
|
#include "ge/ge_api.h" |
|
|
|
#undef private |
|
|
|
#undef protected |
|
|
|
|
|
|
|
using namespace std; |
|
|
|
namespace ge { |
|
|
|
class UtestGraphPreproces : public testing::Test { |
|
|
|
protected: |
|
|
|
void SetUp() { |
|
|
|
map<string, string> options; |
|
|
|
ge::Status ret = ge::GEInitialize(); |
|
|
|
EXPECT_EQ(ret, ge::SUCCESS); |
|
|
|
} |
|
|
|
void TearDown() { |
|
|
|
ge::Status ret = ge::GEFinalize(); |
|
|
|
EXPECT_EQ(ret, ge::SUCCESS); |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
ComputeGraphPtr BuildGraph1(){ |
|
|
|
auto builder = ut::GraphBuilder("g1"); |
|
|
|
auto data1 = builder.AddNode("data1",DATA,0,1); |
|
|
|
return builder.GetGraph(); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) { |
|
|
|
ge::GraphPrepare graph_prepare; |
|
|
|
graph_prepare.compute_graph_ = BuildGraph1(); |
|
|
|
// prepare user_input & graph option |
|
|
|
ge::GeTensorDesc tensor1; |
|
|
|
tensor.SetFormat(ge::FORMAT_NCHW); |
|
|
|
tensor.SetShape(ge::GeShape({3, 12, 5, 5})); |
|
|
|
tensor.SetDataType(ge::DT_FLOAT); |
|
|
|
GeTensor input1 = std::make_shared<GeTensor>(tensor1); |
|
|
|
std::vector<GeTensor> user_input = {input1}; |
|
|
|
std::map<string,string> graph_option = {{"ge.exec.dynamicGraphExecuteMode","dynamic_execute"}, |
|
|
|
{"ge.exec.dataInputsShapeRange","[3,1~20,2~10,5]"}}; |
|
|
|
auto ret = graph_prepare.UpdateInput(input1, graph_option); |
|
|
|
EXPECT_EQ(ret, ge::SUCCESS); |
|
|
|
// check data node output shape_range and shape |
|
|
|
auto data_node = graph_prepare.compute_graph_->FindNode("data1"); |
|
|
|
auto data_output_desc = data_node->GetOpDesc()->GetOutputDescPtr(0); |
|
|
|
vector<int64_t> expect_shape = {3,-1,-1,5}; |
|
|
|
auto result_shape = data_output_desc->GetShape(); |
|
|
|
EXPECT_EQ(result_shape.GetDimNum(), expect_shape.size()); |
|
|
|
for(size_t i =0; i< expect_shape.size(); ++i){ |
|
|
|
EXPECT_EQ(result_shape.GetDim(i), expect_shape.at(i)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |