Browse Source

use true types istead of GeAttrValue::

pull/372/head
CLAY-panjw 4 years ago
parent
commit
e3d033ae28
9 changed files with 175 additions and 1904 deletions
  1. +1
    -1843
      parser/tensorflow/graph_optimizer.cc
  2. +2
    -55
      parser/tensorflow/graph_optimizer.h
  3. +0
    -2
      parser/tensorflow/iterator_fusion_pass.cc
  4. +2
    -3
      parser/tensorflow/iterator_fusion_pass.h
  5. +1
    -1
      parser/tensorflow/tensorflow_parser.cc
  6. +2
    -0
      tests/ut/parser/CMakeLists.txt
  7. +48
    -0
      tests/ut/parser/graph_builder_utils.cc
  8. +48
    -0
      tests/ut/parser/graph_builder_utils.h
  9. +71
    -0
      tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc

+ 1
- 1843
parser/tensorflow/graph_optimizer.cc
File diff suppressed because it is too large
View File


+ 2
- 55
parser/tensorflow/graph_optimizer.h View File

@@ -35,68 +35,16 @@ namespace ge {
class ParserGraphOptimizer {
public:
explicit ParserGraphOptimizer(ge::ComputeGraphPtr graph, domi::FrameworkType type = domi::TENSORFLOW)
: graph_(graph), fmktype_(type), local_fmk_op_flag_(false) {}
: graph_(graph), fmktype_(type) {}

~ParserGraphOptimizer() {}

domi::Status Optimize();

domi::Status OptimizeAfterCal();

domi::Status FusionFmkop();

inline bool IsHCOMOp(const string &op_type) {
return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) ||
(op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) ||
(op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter");
}

void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; }

const bool GetLocalFmkopFlag() const { return local_fmk_op_flag_; }

void SetFuncBinPath(std::string isFuncBinPath) { func_bin_path_ = isFuncBinPath; }
const std::string GetFuncBinPath() const { return func_bin_path_; }

domi::Status InsertHWCK2FZ(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype,
enum ge::Format dstInFormat, enum ge::DataType dstInDatatype);

domi::Status Insert4DTo5DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format src_out_format, enum ge::DataType src_out_data_type,
enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type);

domi::Status InsertFZ2HWCK(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype,
enum ge::Format dstInFormat, enum ge::DataType dstInDatatype);

domi::Status Insert5DTo4DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format src_out_format, enum ge::DataType src_out_data_type,
enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type);

ge::OpDescPtr CreateCastOp(enum ge::DataType input_datatype, enum ge::DataType output_datatype, ge::Format format);

ge::OpDescPtr CreatePermuteOp(enum ge::Format input_format, enum ge::Format output_format);

ge::OpDescPtr CreateTransDataOp(enum ge::Format input_format);

domi::Status NewNodeAddEdges(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, ge::NodePtr first,
ge::NodePtr second, ge::NodePtr third);

domi::Status InsertVar5DTo4D(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype,
enum ge::Format dstInFormat, enum ge::DataType dstInDatatype);

ge::OpDescPtr CreateTranslateOp(enum ge::Format inFormat, ge::DataType inDatatype, enum ge::Format outFormat,
ge::DataType outDatatype);

private:
ge::ComputeGraphPtr graph_;
domi::FrameworkType fmktype_;
// local fmkop flag
bool local_fmk_op_flag_;
std::string func_bin_path_;

domi::Status FindFmkNodeCluser(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map);

domi::Status MarkForFusion(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map);
@@ -122,7 +70,6 @@ class ParserGraphOptimizer {
vector<ge::InControlAnchorPtr> &input_control_anchors,
vector<ge::OutControlAnchorPtr> &output_control_anchors, ge::NodePtr fusion_node);

domi::Status MakeTfProtoDef();
};
} // namespace ge
#endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_

+ 0
- 2
parser/tensorflow/iterator_fusion_pass.cc View File

@@ -32,8 +32,6 @@ Status IteratorFusionPass::Run(ge::ComputeGraphPtr graph) {
REPORT_CALL_ERROR("E19999", "New ParserGraphOptimizer failed");
return FAILED;
}

graph_optimizer->SetLocalFmkopFlag(local_fmk_op_flag_);
return graph_optimizer->FusionFmkop();
}
} // namespace ge

+ 2
- 3
parser/tensorflow/iterator_fusion_pass.h View File

@@ -23,8 +23,8 @@
namespace ge {
class IteratorFusionPass : public GraphPass {
public:
IteratorFusionPass(domi::FrameworkType type, bool local_fmk_op_flag)
: fmk_type_(type), local_fmk_op_flag_(local_fmk_op_flag) {}
IteratorFusionPass(domi::FrameworkType type)
: fmk_type_(type) {}

virtual ~IteratorFusionPass() {}

@@ -32,7 +32,6 @@ class IteratorFusionPass : public GraphPass {

private:
domi::FrameworkType fmk_type_;
bool local_fmk_op_flag_;
};
} // namespace ge



+ 1
- 1
parser/tensorflow/tensorflow_parser.cc View File

@@ -2375,7 +2375,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
ge::parser::PassManager iterator_fusion_pass;
try {
(void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass",
new ge::IteratorFusionPass(domi::TENSORFLOW, false));
new ge::IteratorFusionPass(domi::TENSORFLOW));
} catch (std::bad_alloc &e) {
GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs.");
return INTERNAL_ERROR;


+ 2
- 0
tests/ut/parser/CMakeLists.txt View File

@@ -307,6 +307,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework)


set(PARSER_UT_FILES
"graph_builder_utils.cc"
"parser_ut_utils.cc"
"testcase/common/acl_graph_parser_unittest.cc"
"testcase/onnx_parser_testcase/onnx_parser_unittest.cc"
@@ -314,6 +315,7 @@ set(PARSER_UT_FILES
"testcase/onnx_parser_testcase/message2operator_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc"
"testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc"
)

############ libut_parser_common.a ############


+ 48
- 0
tests/ut/parser/graph_builder_utils.cc View File

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph_builder_utils.h"

#include "graph/utils/graph_utils.h"

namespace ge {
namespace ut {
NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format,
DataType data_type, std::vector<int64_t> shape) {
auto tensor_desc = std::make_shared<GeTensorDesc>();
tensor_desc->SetShape(GeShape(std::move(shape)));
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);

auto op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < in_cnt; ++i) {
op_desc->AddInputDesc(tensor_desc->Clone());
}
for (int i = 0; i < out_cnt; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone());
}

return graph_->AddNode(op_desc);
}
void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) {
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}
void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) {
GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}

} // namespace ut
} // namespace ge

+ 48
- 0
tests/ut/parser/graph_builder_utils.h View File

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_
#define MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_

#include <string>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/graph.h"
#include "graph/node.h"

namespace ge {
namespace ut {
class GraphBuilder {
public:
explicit GraphBuilder(const std::string &name) { graph_ = std::make_shared<ComputeGraph>(name); }
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
std::vector<int64_t> shape = {1, 1, 224, 224});
void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx);
void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node);
ComputeGraphPtr GetGraph() {
graph_->TopologicalSorting();
return graph_;
}

private:
ComputeGraphPtr graph_;
};
} // namespace ut
} // namespace ge

#endif // MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_

+ 71
- 0
tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc View File

@@ -0,0 +1,71 @@
#include <gtest/gtest.h>
#include <iostream>
#include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph_builder_utils.h"
#include "common/util.h"

#include "tensorflow/iterator_fusion_pass.h"
#include "parser/common/acl_graph_parser_util.h"
#define private public
#include "tensorflow/graph_optimizer.h"
#undef private

namespace ge {
class UtestGraphOptimizer : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};

namespace {
ComputeGraphPtr MakeGraph() {
  ge::ut::GraphBuilder builder("graph");
  std::string name = "graph";
  std::string original_type ;

  original_type = "IteratorV2";
  auto data1 = builder.AddNode(name + "_"+original_type, ge::parser::FRAMEWORKOP, 1, 1);
  ge::AttrUtils::SetStr(data1->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type);

original_type = "IteratorGetNext";
auto data2 = builder.AddNode(name + "_"+original_type+"2", ge::parser::FRAMEWORKOP, 1, 2);
ge::AttrUtils::SetStr(data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type);

string nodefStr;
AttrUtils::SetZeroCopyBytes(
data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length()));

original_type = "IteratorGetNext";
auto data3 = builder.AddNode(name + "_"+original_type+"3", ge::parser::FRAMEWORKOP, 2, 1);
ge::AttrUtils::SetStr(data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type);

AttrUtils::SetZeroCopyBytes(
data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length()));
builder.AddDataEdge(data1, 0, data2, 0);
builder.AddDataEdge(data2, 0, data3, 0);
builder.AddDataEdge(data2, 1, data3, 1);
return builder.GetGraph();
}
}

TEST_F(UtestGraphOptimizer, graph_optimizer) {
ge::ComputeGraphPtr graph = MakeGraph();
ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW);
EXPECT_NE(iteratorFusionPass.Run(graph),ge::SUCCESS);
}

TEST_F(UtestGraphOptimizer, graph_optimizer_output) {
ge::ComputerGraph graph = MakeGraph();
domi::FrameworkType type = domi::TENSORFLOW;
  ge::ParserGraphOptimizer parserGraphOptimizer(graph,type);
  vector<ge::InDataAnchorPtr> input_anchors;
  vector<ge::OutDataAnchorPtr> output_anchors;
  ge::OpDescPtr fusion_op_desc;
  EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors,fusion_op_desc),ge::SUCCESS);
EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors,fusion_op_desc),ge::SUCCESS);
}
}

Loading…
Cancel
Save