Browse Source

use true type

pull/372/head
CLAY-panjw 4 years ago
parent
commit
03f2f43926
2 changed files with 0 additions and 96 deletions
  1. +0
    -48
      tests/ut/parser/graph_builder_utils.cc
  2. +0
    -48
      tests/ut/parser/graph_builder_utils.h

+ 0
- 48
tests/ut/parser/graph_builder_utils.cc View File

@@ -1,48 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph_builder_utils.h"

#include "graph/utils/graph_utils.h"

namespace ge {
namespace ut {
NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format,
DataType data_type, std::vector<int64_t> shape) {
auto tensor_desc = std::make_shared<GeTensorDesc>();
tensor_desc->SetShape(GeShape(std::move(shape)));
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);

auto op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < in_cnt; ++i) {
op_desc->AddInputDesc(tensor_desc->Clone());
}
for (int i = 0; i < out_cnt; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone());
}

return graph_->AddNode(op_desc);
}
void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) {
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}
// void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) {
// GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
// }

} // namespace ut
} // namespace ge

+ 0
- 48
tests/ut/parser/graph_builder_utils.h View File

@@ -1,48 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_
#define MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_

#include <string>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/graph.h"
#include "graph/node.h"

namespace ge {
namespace ut {
class GraphBuilder {
public:
explicit GraphBuilder(const std::string &name) { graph_ = std::make_shared<ComputeGraph>(name); }
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
std::vector<int64_t> shape = {1, 1, 224, 224});
void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx);
void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node);
ComputeGraphPtr GetGraph() {
graph_->TopologicalSorting();
return graph_;
}

private:
ComputeGraphPtr graph_;
};
} // namespace ut
} // namespace ge

#endif // MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_

Loading…
Cancel
Save