| @@ -30,11 +30,9 @@ | |||
| #include "utils/ms_context.h" | |||
| #include "utils/file_utils.h" | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| #include "backend/optimizer/graph_kernel/update_state_formatter.h" | |||
| namespace mindspore::graphkernel { | |||
| std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() { | |||
| @@ -109,14 +107,6 @@ std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() { | |||
| return clusterable_ops; | |||
| } | |||
| namespace { | |||
| size_t CountGraphKernelInnerNodes(const AnfNodePtr &node) { | |||
| AnfNodePtrList node_list; | |||
| kernel::GetValidKernelNodes(AnfAlgo::GetCNodeFuncGraphPtr(node), &node_list); | |||
| return node_list.size(); | |||
| } | |||
| } // namespace | |||
| bool GraphKernelCluster::IsClusterableOp(const AnfNodePtr &node) { | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| return true; | |||
| @@ -142,19 +132,12 @@ class Graph { | |||
| struct Cluster { | |||
| size_t cluster_id_; // node_id of the representative. | |||
| size_t cluster_size_{1}; // size of cluster, composite node is considered as one node. | |||
| size_t basic_op_cnt_{1}; // basic node count, the inner nodes of composite node are counted. | |||
| std::set<size_t> inputs_; // inputs' cluster_id. | |||
| size_t seed_{0}; // visited flag of dfs. | |||
| size_t max_node_id_; // largest node id of a cluster | |||
| Cluster(size_t node_id, const AnfNodePtr &node, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map) | |||
| : cluster_id_(node_id), max_node_id_(node_id) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||
| basic_op_cnt_ = 0; | |||
| } else if (AnfAlgo::IsGraphKernel(node)) { | |||
| // the basic_op_cnt_ is used to limit the composite op size | |||
| basic_op_cnt_ = CountGraphKernelInnerNodes(node); | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (const auto &inp : cnode->inputs()) { | |||
| @@ -171,7 +154,6 @@ class Graph { | |||
| other_cluster->cluster_id_ = cluster_id_; | |||
| max_node_id_ = std::max(other_cluster->max_node_id_, max_node_id_); | |||
| cluster_size_ += other_cluster->cluster_size_; | |||
| basic_op_cnt_ += other_cluster->basic_op_cnt_; | |||
| (void)std::for_each(other_cluster->inputs_.begin(), other_cluster->inputs_.end(), | |||
| [this](size_t inp) { (void)this->inputs_.insert(inp); }); | |||
| other_cluster->Clean(); | |||
| @@ -181,7 +163,6 @@ class Graph { | |||
| void Clean() { | |||
| inputs_.clear(); | |||
| cluster_size_ = 0; | |||
| basic_op_cnt_ = 0; | |||
| max_node_id_ = 0; | |||
| } | |||
| }; // struct Cluster | |||
| @@ -232,9 +213,6 @@ class Graph { | |||
| // Get cluster size | |||
| size_t GetSize(size_t cluster_id) { return clusters_[Find(cluster_id)].cluster_size_; } | |||
| // Get cluster's basic op count | |||
| size_t GetBasicNodeCount(size_t cluster_id) { return clusters_[Find(cluster_id)].basic_op_cnt_; } | |||
| // Get cluster's inputs | |||
| const std::set<size_t> &GetInputs(size_t cluster_id) { | |||
| cluster_id = Find(cluster_id); | |||
| @@ -505,7 +483,6 @@ void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) { | |||
| } | |||
| bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) { | |||
| (void)std::make_shared<ShrinkUpdateState>()->Run(func_graph); | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| Init(func_graph); | |||
| @@ -518,7 +495,6 @@ bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) { | |||
| mng->KeepRoots({func_graph}); | |||
| } | |||
| Clean(); | |||
| (void)std::make_shared<SpreadUpdateState>()->Run(func_graph); | |||
| return changed; | |||
| } | |||
| } // namespace mindspore::graphkernel | |||