| @@ -35,68 +35,16 @@ namespace ge { | |||
| class ParserGraphOptimizer { | |||
| public: | |||
| explicit ParserGraphOptimizer(ge::ComputeGraphPtr graph, domi::FrameworkType type = domi::TENSORFLOW) | |||
| : graph_(graph), fmktype_(type), local_fmk_op_flag_(false) {} | |||
| : graph_(graph), fmktype_(type) {} | |||
| ~ParserGraphOptimizer() {} | |||
| domi::Status Optimize(); | |||
| domi::Status OptimizeAfterCal(); | |||
| domi::Status FusionFmkop(); | |||
| inline bool IsHCOMOp(const string &op_type) { | |||
| return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) || | |||
| (op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) || | |||
| (op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter"); | |||
| } | |||
| void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; } | |||
| const bool GetLocalFmkopFlag() const { return local_fmk_op_flag_; } | |||
| void SetFuncBinPath(std::string isFuncBinPath) { func_bin_path_ = isFuncBinPath; } | |||
| const std::string GetFuncBinPath() const { return func_bin_path_; } | |||
| domi::Status InsertHWCK2FZ(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, | |||
| enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); | |||
| domi::Status Insert4DTo5DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format src_out_format, enum ge::DataType src_out_data_type, | |||
| enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); | |||
| domi::Status InsertFZ2HWCK(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, | |||
| enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); | |||
| domi::Status Insert5DTo4DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format src_out_format, enum ge::DataType src_out_data_type, | |||
| enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); | |||
| ge::OpDescPtr CreateCastOp(enum ge::DataType input_datatype, enum ge::DataType output_datatype, ge::Format format); | |||
| ge::OpDescPtr CreatePermuteOp(enum ge::Format input_format, enum ge::Format output_format); | |||
| ge::OpDescPtr CreateTransDataOp(enum ge::Format input_format); | |||
| domi::Status NewNodeAddEdges(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, ge::NodePtr first, | |||
| ge::NodePtr second, ge::NodePtr third); | |||
| domi::Status InsertVar5DTo4D(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, | |||
| enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); | |||
| ge::OpDescPtr CreateTranslateOp(enum ge::Format inFormat, ge::DataType inDatatype, enum ge::Format outFormat, | |||
| ge::DataType outDatatype); | |||
| private: | |||
| ge::ComputeGraphPtr graph_; | |||
| domi::FrameworkType fmktype_; | |||
| // local fmkop flag | |||
| bool local_fmk_op_flag_; | |||
| std::string func_bin_path_; | |||
| domi::Status FindFmkNodeCluser(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map); | |||
| domi::Status MarkForFusion(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map); | |||
| @@ -122,7 +70,6 @@ class ParserGraphOptimizer { | |||
| vector<ge::InControlAnchorPtr> &input_control_anchors, | |||
| vector<ge::OutControlAnchorPtr> &output_control_anchors, ge::NodePtr fusion_node); | |||
| domi::Status MakeTfProtoDef(); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ | |||
| @@ -32,8 +32,6 @@ Status IteratorFusionPass::Run(ge::ComputeGraphPtr graph) { | |||
| REPORT_CALL_ERROR("E19999", "New ParserGraphOptimizer failed"); | |||
| return FAILED; | |||
| } | |||
| graph_optimizer->SetLocalFmkopFlag(local_fmk_op_flag_); | |||
| return graph_optimizer->FusionFmkop(); | |||
| } | |||
| } // namespace ge | |||
| @@ -23,8 +23,8 @@ | |||
| namespace ge { | |||
| class IteratorFusionPass : public GraphPass { | |||
| public: | |||
| IteratorFusionPass(domi::FrameworkType type, bool local_fmk_op_flag) | |||
| : fmk_type_(type), local_fmk_op_flag_(local_fmk_op_flag) {} | |||
| IteratorFusionPass(domi::FrameworkType type) | |||
| : fmk_type_(type) {} | |||
| virtual ~IteratorFusionPass() {} | |||
| @@ -32,7 +32,6 @@ class IteratorFusionPass : public GraphPass { | |||
| private: | |||
| domi::FrameworkType fmk_type_; | |||
| bool local_fmk_op_flag_; | |||
| }; | |||
| } // namespace ge | |||
| @@ -2375,7 +2375,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||
| ge::parser::PassManager iterator_fusion_pass; | |||
| try { | |||
| (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", | |||
| new ge::IteratorFusionPass(domi::TENSORFLOW, false)); | |||
| new ge::IteratorFusionPass(domi::TENSORFLOW)); | |||
| } catch (std::bad_alloc &e) { | |||
| GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | |||
| return INTERNAL_ERROR; | |||
| @@ -307,6 +307,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) | |||
| set(PARSER_UT_FILES | |||
| "graph_builder_utils.cc" | |||
| "parser_ut_utils.cc" | |||
| "testcase/common/acl_graph_parser_unittest.cc" | |||
| "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" | |||
| @@ -314,6 +315,7 @@ set(PARSER_UT_FILES | |||
| "testcase/onnx_parser_testcase/message2operator_unittest.cc" | |||
| "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | |||
| "testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc" | |||
| "testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc" | |||
| ) | |||
| ############ libut_parser_common.a ############ | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph_builder_utils.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| namespace ge { | |||
| namespace ut { | |||
| NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, | |||
| DataType data_type, std::vector<int64_t> shape) { | |||
| auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||
| tensor_desc->SetShape(GeShape(std::move(shape))); | |||
| tensor_desc->SetFormat(format); | |||
| tensor_desc->SetDataType(data_type); | |||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||
| for (int i = 0; i < in_cnt; ++i) { | |||
| op_desc->AddInputDesc(tensor_desc->Clone()); | |||
| } | |||
| for (int i = 0; i < out_cnt; ++i) { | |||
| op_desc->AddOutputDesc(tensor_desc->Clone()); | |||
| } | |||
| return graph_->AddNode(op_desc); | |||
| } | |||
| void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { | |||
| GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); | |||
| } | |||
| void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { | |||
| GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); | |||
| } | |||
| } // namespace ut | |||
| } // namespace ge | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ | |||
| #define MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/compute_graph.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/node.h" | |||
| namespace ge { | |||
| namespace ut { | |||
| class GraphBuilder { | |||
| public: | |||
| explicit GraphBuilder(const std::string &name) { graph_ = std::make_shared<ComputeGraph>(name); } | |||
| NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||
| Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||
| std::vector<int64_t> shape = {1, 1, 224, 224}); | |||
| void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx); | |||
| void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node); | |||
| ComputeGraphPtr GetGraph() { | |||
| graph_->TopologicalSorting(); | |||
| return graph_; | |||
| } | |||
| private: | |||
| ComputeGraphPtr graph_; | |||
| }; | |||
| } // namespace ut | |||
| } // namespace ge | |||
| #endif // MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ | |||
| @@ -0,0 +1,71 @@ | |||
| #include <gtest/gtest.h> | |||
| #include <iostream> | |||
| #include "graph/utils/attr_utils.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph_builder_utils.h" | |||
| #include "common/util.h" | |||
| #include "tensorflow/iterator_fusion_pass.h" | |||
| #include "parser/common/acl_graph_parser_util.h" | |||
| #define private public | |||
| #include "tensorflow/graph_optimizer.h" | |||
| #undef private | |||
| namespace ge { | |||
| class UtestGraphOptimizer : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| namespace { | |||
| ComputeGraphPtr MakeGraph() { | |||
| ge::ut::GraphBuilder builder("graph"); | |||
| std::string name = "graph"; | |||
| std::string original_type ; | |||
| original_type = "IteratorV2"; | |||
| auto data1 = builder.AddNode(name + "_"+original_type, ge::parser::FRAMEWORKOP, 1, 1); | |||
| ge::AttrUtils::SetStr(data1->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); | |||
| original_type = "IteratorGetNext"; | |||
| auto data2 = builder.AddNode(name + "_"+original_type+"2", ge::parser::FRAMEWORKOP, 1, 2); | |||
| ge::AttrUtils::SetStr(data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); | |||
| string nodefStr; | |||
| AttrUtils::SetZeroCopyBytes( | |||
| data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | |||
| Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length())); | |||
| original_type = "IteratorGetNext"; | |||
| auto data3 = builder.AddNode(name + "_"+original_type+"3", ge::parser::FRAMEWORKOP, 2, 1); | |||
| ge::AttrUtils::SetStr(data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); | |||
| AttrUtils::SetZeroCopyBytes( | |||
| data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | |||
| Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length())); | |||
| builder.AddDataEdge(data1, 0, data2, 0); | |||
| builder.AddDataEdge(data2, 0, data3, 0); | |||
| builder.AddDataEdge(data2, 1, data3, 1); | |||
| return builder.GetGraph(); | |||
| } | |||
| } | |||
| TEST_F(UtestGraphOptimizer, graph_optimizer) { | |||
| ge::ComputeGraphPtr graph = MakeGraph(); | |||
| ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW); | |||
| EXPECT_NE(iteratorFusionPass.Run(graph),ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestGraphOptimizer, graph_optimizer_output) { | |||
| ge::ComputerGraph graph = MakeGraph(); | |||
| domi::FrameworkType type = domi::TENSORFLOW; | |||
| ge::ParserGraphOptimizer parserGraphOptimizer(graph,type); | |||
| vector<ge::InDataAnchorPtr> input_anchors; | |||
| vector<ge::OutDataAnchorPtr> output_anchors; | |||
| ge::OpDescPtr fusion_op_desc; | |||
| EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors,fusion_op_desc),ge::SUCCESS); | |||
| EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors,fusion_op_desc),ge::SUCCESS); | |||
| } | |||
| } | |||