| @@ -35,68 +35,16 @@ namespace ge { | |||||
| class ParserGraphOptimizer { | class ParserGraphOptimizer { | ||||
| public: | public: | ||||
| explicit ParserGraphOptimizer(ge::ComputeGraphPtr graph, domi::FrameworkType type = domi::TENSORFLOW) | explicit ParserGraphOptimizer(ge::ComputeGraphPtr graph, domi::FrameworkType type = domi::TENSORFLOW) | ||||
| : graph_(graph), fmktype_(type), local_fmk_op_flag_(false) {} | |||||
| : graph_(graph), fmktype_(type) {} | |||||
| ~ParserGraphOptimizer() {} | ~ParserGraphOptimizer() {} | ||||
| domi::Status Optimize(); | |||||
| domi::Status OptimizeAfterCal(); | |||||
| domi::Status FusionFmkop(); | domi::Status FusionFmkop(); | ||||
| inline bool IsHCOMOp(const string &op_type) { | |||||
| return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) || | |||||
| (op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) || | |||||
| (op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter"); | |||||
| } | |||||
| void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; } | |||||
| const bool GetLocalFmkopFlag() const { return local_fmk_op_flag_; } | |||||
| void SetFuncBinPath(std::string isFuncBinPath) { func_bin_path_ = isFuncBinPath; } | |||||
| const std::string GetFuncBinPath() const { return func_bin_path_; } | |||||
| domi::Status InsertHWCK2FZ(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||||
| enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, | |||||
| enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); | |||||
| domi::Status Insert4DTo5DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||||
| enum ge::Format src_out_format, enum ge::DataType src_out_data_type, | |||||
| enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); | |||||
| domi::Status InsertFZ2HWCK(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||||
| enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, | |||||
| enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); | |||||
| domi::Status Insert5DTo4DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||||
| enum ge::Format src_out_format, enum ge::DataType src_out_data_type, | |||||
| enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); | |||||
| ge::OpDescPtr CreateCastOp(enum ge::DataType input_datatype, enum ge::DataType output_datatype, ge::Format format); | |||||
| ge::OpDescPtr CreatePermuteOp(enum ge::Format input_format, enum ge::Format output_format); | |||||
| ge::OpDescPtr CreateTransDataOp(enum ge::Format input_format); | |||||
| domi::Status NewNodeAddEdges(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, ge::NodePtr first, | |||||
| ge::NodePtr second, ge::NodePtr third); | |||||
| domi::Status InsertVar5DTo4D(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||||
| enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, | |||||
| enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); | |||||
| ge::OpDescPtr CreateTranslateOp(enum ge::Format inFormat, ge::DataType inDatatype, enum ge::Format outFormat, | |||||
| ge::DataType outDatatype); | |||||
| private: | private: | ||||
| ge::ComputeGraphPtr graph_; | ge::ComputeGraphPtr graph_; | ||||
| domi::FrameworkType fmktype_; | domi::FrameworkType fmktype_; | ||||
| // local fmkop flag | |||||
| bool local_fmk_op_flag_; | |||||
| std::string func_bin_path_; | |||||
| domi::Status FindFmkNodeCluser(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map); | domi::Status FindFmkNodeCluser(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map); | ||||
| domi::Status MarkForFusion(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map); | domi::Status MarkForFusion(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map); | ||||
| @@ -122,7 +70,6 @@ class ParserGraphOptimizer { | |||||
| vector<ge::InControlAnchorPtr> &input_control_anchors, | vector<ge::InControlAnchorPtr> &input_control_anchors, | ||||
| vector<ge::OutControlAnchorPtr> &output_control_anchors, ge::NodePtr fusion_node); | vector<ge::OutControlAnchorPtr> &output_control_anchors, ge::NodePtr fusion_node); | ||||
| domi::Status MakeTfProtoDef(); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ | #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ | ||||
| @@ -32,8 +32,6 @@ Status IteratorFusionPass::Run(ge::ComputeGraphPtr graph) { | |||||
| REPORT_CALL_ERROR("E19999", "New ParserGraphOptimizer failed"); | REPORT_CALL_ERROR("E19999", "New ParserGraphOptimizer failed"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| graph_optimizer->SetLocalFmkopFlag(local_fmk_op_flag_); | |||||
| return graph_optimizer->FusionFmkop(); | return graph_optimizer->FusionFmkop(); | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -23,8 +23,8 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class IteratorFusionPass : public GraphPass { | class IteratorFusionPass : public GraphPass { | ||||
| public: | public: | ||||
| IteratorFusionPass(domi::FrameworkType type, bool local_fmk_op_flag) | |||||
| : fmk_type_(type), local_fmk_op_flag_(local_fmk_op_flag) {} | |||||
| IteratorFusionPass(domi::FrameworkType type) | |||||
| : fmk_type_(type) {} | |||||
| virtual ~IteratorFusionPass() {} | virtual ~IteratorFusionPass() {} | ||||
| @@ -32,7 +32,6 @@ class IteratorFusionPass : public GraphPass { | |||||
| private: | private: | ||||
| domi::FrameworkType fmk_type_; | domi::FrameworkType fmk_type_; | ||||
| bool local_fmk_op_flag_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -2375,7 +2375,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||||
| ge::parser::PassManager iterator_fusion_pass; | ge::parser::PassManager iterator_fusion_pass; | ||||
| try { | try { | ||||
| (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", | (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", | ||||
| new ge::IteratorFusionPass(domi::TENSORFLOW, false)); | |||||
| new ge::IteratorFusionPass(domi::TENSORFLOW)); | |||||
| } catch (std::bad_alloc &e) { | } catch (std::bad_alloc &e) { | ||||
| GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -307,6 +307,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) | |||||
| set(PARSER_UT_FILES | set(PARSER_UT_FILES | ||||
| "graph_builder_utils.cc" | |||||
| "parser_ut_utils.cc" | "parser_ut_utils.cc" | ||||
| "testcase/common/acl_graph_parser_unittest.cc" | "testcase/common/acl_graph_parser_unittest.cc" | ||||
| "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" | "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" | ||||
| @@ -314,6 +315,7 @@ set(PARSER_UT_FILES | |||||
| "testcase/onnx_parser_testcase/message2operator_unittest.cc" | "testcase/onnx_parser_testcase/message2operator_unittest.cc" | ||||
| "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | ||||
| "testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc" | "testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc" | ||||
| "testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc" | |||||
| ) | ) | ||||
| ############ libut_parser_common.a ############ | ############ libut_parser_common.a ############ | ||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph_builder_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| namespace ge { | |||||
| namespace ut { | |||||
| // Creates a node called `name` with operator type `type` and adds it to the graph being built. | |||||
| // One tensor descriptor is assembled from `shape`, `format` and `data_type`; each of the `in_cnt` | |||||
| // inputs and `out_cnt` outputs receives its own clone of it. Returns the result of ComputeGraph::AddNode. | |||||
| NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, | |||||
|                               DataType data_type, std::vector<int64_t> shape) { | |||||
|   // Descriptor template that is cloned for every input and output below. | |||||
|   auto desc_template = std::make_shared<GeTensorDesc>(); | |||||
|   desc_template->SetShape(GeShape(std::move(shape))); | |||||
|   desc_template->SetFormat(format); | |||||
|   desc_template->SetDataType(data_type); | |||||
|   auto node_desc = std::make_shared<OpDesc>(name, type); | |||||
|   // Non-positive counts add no descriptors at all. | |||||
|   for (int remaining = in_cnt; remaining > 0; --remaining) { | |||||
|     node_desc->AddInputDesc(desc_template->Clone()); | |||||
|   } | |||||
|   for (int remaining = out_cnt; remaining > 0; --remaining) { | |||||
|     node_desc->AddOutputDesc(desc_template->Clone()); | |||||
|   } | |||||
|   return graph_->AddNode(node_desc); | |||||
| } | |||||
| // Links output data anchor `src_idx` of `src_node` to input data anchor `dst_idx` of `dst_node`. | |||||
| // The status returned by GraphUtils::AddEdge is not inspected. | |||||
| void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { | |||||
|   const auto out_anchor = src_node->GetOutDataAnchor(src_idx); | |||||
|   const auto in_anchor = dst_node->GetInDataAnchor(dst_idx); | |||||
|   GraphUtils::AddEdge(out_anchor, in_anchor); | |||||
| } | |||||
| // Links the output control anchor of `src_node` to the input control anchor of `dst_node`. | |||||
| // The status returned by GraphUtils::AddEdge is not inspected. | |||||
| void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { | |||||
|   const auto out_ctrl_anchor = src_node->GetOutControlAnchor(); | |||||
|   const auto in_ctrl_anchor = dst_node->GetInControlAnchor(); | |||||
|   GraphUtils::AddEdge(out_ctrl_anchor, in_ctrl_anchor); | |||||
| } | |||||
| } // namespace ut | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ | |||||
| #define MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/graph.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| namespace ut { | |||||
| // Helper that assembles a ComputeGraph one node and one edge at a time. | |||||
| class GraphBuilder { | |||||
|  public: | |||||
|   // Creates the ComputeGraph called `name` that this builder populates. | |||||
|   explicit GraphBuilder(const std::string &name) : graph_(std::make_shared<ComputeGraph>(name)) {} | |||||
|   // Adds a node with `in_cnt` inputs and `out_cnt` outputs, all described by the same | |||||
|   // format / data type / shape, and returns it. | |||||
|   NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||||
|                   Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||||
|                   std::vector<int64_t> shape = {1, 1, 224, 224}); | |||||
|   // Connects data output `src_idx` of `src_node` to data input `dst_idx` of `dst_node`. | |||||
|   void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx); | |||||
|   // Connects the control output of `src_node` to the control input of `dst_node`. | |||||
|   void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node); | |||||
|   // Runs TopologicalSorting() on the graph (its result is not checked) and hands the graph out. | |||||
|   ComputeGraphPtr GetGraph() { | |||||
|     graph_->TopologicalSorting(); | |||||
|     return graph_; | |||||
|   } | |||||
|  private: | |||||
|   // Graph under construction; shared with every caller of GetGraph(). | |||||
|   ComputeGraphPtr graph_; | |||||
| }; | |||||
| } // namespace ut | |||||
| } // namespace ge | |||||
| #endif // MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ | |||||
| @@ -0,0 +1,71 @@ | |||||
| #include <gtest/gtest.h> | |||||
| #include <iostream> | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph_builder_utils.h" | |||||
| #include "common/util.h" | |||||
| #include "tensorflow/iterator_fusion_pass.h" | |||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #define private public | |||||
| #include "tensorflow/graph_optimizer.h" | |||||
| #undef private | |||||
| namespace ge { | |||||
| // Fixture for the graph optimizer tests; it needs no per-test setup or cleanup. | |||||
| class UtestGraphOptimizer : public testing::Test { | |||||
|  protected: | |||||
|   void SetUp() override {} | |||||
|   void TearDown() override {} | |||||
| }; | |||||
| namespace { | |||||
| // Builds a three-node chain of FRAMEWORKOP nodes: | |||||
| //   graph_IteratorV2 -> graph_IteratorGetNext2 => graph_IteratorGetNext3 | |||||
| // The first link is one data edge, the second is two data edges (outputs 0 and 1 to inputs 0 and 1). | |||||
| // Both IteratorGetNext nodes also carry an empty ATTR_NAME_FRAMEWORK_NODE_DEF byte buffer. | |||||
| ComputeGraphPtr MakeGraph() { | |||||
|   ge::ut::GraphBuilder builder("graph"); | |||||
|   const std::string prefix = "graph"; | |||||
|   // Default-constructed, so the node-def payload attached below has length zero. | |||||
|   const std::string empty_node_def; | |||||
|   // Records the original framework operator type on a node. | |||||
|   auto tag_original_type = [](const NodePtr &node, const std::string &framework_type) { | |||||
|     ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, framework_type); | |||||
|   }; | |||||
|   // Stores a copy of the (empty) node-def bytes on a node. | |||||
|   auto attach_node_def = [&empty_node_def](const NodePtr &node) { | |||||
|     AttrUtils::SetZeroCopyBytes( | |||||
|         node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | |||||
|         Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(empty_node_def.data()), empty_node_def.length())); | |||||
|   }; | |||||
|   const std::string iterator_type = "IteratorV2"; | |||||
|   auto iterator_node = builder.AddNode(prefix + "_" + iterator_type, ge::parser::FRAMEWORKOP, 1, 1); | |||||
|   tag_original_type(iterator_node, iterator_type); | |||||
|   const std::string get_next_type = "IteratorGetNext"; | |||||
|   auto first_get_next = builder.AddNode(prefix + "_" + get_next_type + "2", ge::parser::FRAMEWORKOP, 1, 2); | |||||
|   tag_original_type(first_get_next, get_next_type); | |||||
|   attach_node_def(first_get_next); | |||||
|   auto second_get_next = builder.AddNode(prefix + "_" + get_next_type + "3", ge::parser::FRAMEWORKOP, 2, 1); | |||||
|   tag_original_type(second_get_next, get_next_type); | |||||
|   attach_node_def(second_get_next); | |||||
|   builder.AddDataEdge(iterator_node, 0, first_get_next, 0); | |||||
|   builder.AddDataEdge(first_get_next, 0, second_get_next, 0); | |||||
|   builder.AddDataEdge(first_get_next, 1, second_get_next, 1); | |||||
|   return builder.GetGraph(); | |||||
| } | |||||
| } | |||||
| // Runs IteratorFusionPass over the graph from MakeGraph() and expects a non-SUCCESS status. | |||||
| TEST_F(UtestGraphOptimizer, graph_optimizer) { | |||||
|   ge::ComputeGraphPtr graph = MakeGraph(); | |||||
|   ge::IteratorFusionPass fusion_pass(domi::TENSORFLOW); | |||||
|   const auto run_status = fusion_pass.Run(graph); | |||||
|   EXPECT_NE(run_status, ge::SUCCESS); | |||||
| } | |||||
| // Calls the anchor-rebuilding helpers of ParserGraphOptimizer with empty anchor lists and a null | |||||
| // fusion op descriptor, and expects both calls to report a non-SUCCESS status. | |||||
| TEST_F(UtestGraphOptimizer, graph_optimizer_output) { | |||||
|   // MakeGraph() returns a ComputeGraphPtr, which is also what the optimizer's constructor takes. | |||||
|   ge::ComputeGraphPtr graph = MakeGraph(); | |||||
|   domi::FrameworkType type = domi::TENSORFLOW; | |||||
|   ge::ParserGraphOptimizer parserGraphOptimizer(graph, type); | |||||
|   vector<ge::InDataAnchorPtr> input_anchors; | |||||
|   vector<ge::OutDataAnchorPtr> output_anchors; | |||||
|   // Left default-constructed (null) on purpose: the helpers below are expected to fail. | |||||
|   ge::OpDescPtr fusion_op_desc; | |||||
|   EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors, fusion_op_desc), ge::SUCCESS); | |||||
|   EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors, fusion_op_desc), ge::SUCCESS); | |||||
| } | |||||
| } | |||||