From 176c028c65908fbf3ed767fcfb818c427c53638d Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Wed, 9 Dec 2020 14:51:34 +0800 Subject: [PATCH] add initcostgraph to rec_algo --- .../frontend/parallel/auto_parallel/graph_costmodel.cc | 10 ++++++++++ .../frontend/parallel/auto_parallel/graph_costmodel.h | 1 + .../ccsrc/frontend/parallel/step_auto_parallel.cc | 6 +++++- mindspore/core/ir/func_graph_cloner.cc | 2 +- mindspore/core/utils/parallel_node_check.cc | 2 +- mindspore/core/utils/parallel_node_check.h | 2 +- 6 files changed, 19 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc index 4965751164..86a83284e2 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc @@ -198,6 +198,16 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { } } +void CostGraph::Init() { + inputs_tensor_name_list_.clear(); + tuple_getitem_list_.clear(); + ops_.clear(); + edges_.clear(); + connected_compoents_.clear(); + out_edges_.clear(); + in_edges_.clear(); +} + void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { for (auto it = ops_.begin(); it != ops_.end();) { if ((*it) == op) { diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h index 5721f7f786..2e48af620b 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h @@ -61,6 +61,7 @@ class CostGraph { costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; } ~CostGraph() = default; + void Init(); void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } OperatorInfoPtr FindOperatorByIndex(size_t index) { if (index >= ops_.size()) { diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index cb347102f8..9cd5e35e45 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -226,8 +226,11 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn } void InitCostGraph() { - entire_costgraph = std::make_shared(); + if (entire_costgraph == nullptr) { + entire_costgraph = std::make_shared(); + } entire_costgraph->SetDeviceMemoryAndCostParameter(); + entire_costgraph->Init(); } OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { @@ -974,6 +977,7 @@ void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, co } Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root) { + InitCostGraph(); if (CostModelContext::GetInstance()->is_multi_subgraphs()) { if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index d0451e4a34..02331237d0 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -92,7 +92,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { new_node->set_inputs_value(old_node->inputs_value()); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); new_node->set_scope(scope); - if (IsParallelCareCNode(old_node) && new_node->scope() == kDefaultScope) { + if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) { new_node->set_fullname_with_scope(old_node->fullname_with_scope()); } new_node->set_kernel_info(old_node->kernel_info_ptr()); diff --git a/mindspore/core/utils/parallel_node_check.cc b/mindspore/core/utils/parallel_node_check.cc index ac0aa23318..ac0c7e8c5b 100644 --- a/mindspore/core/utils/parallel_node_check.cc +++ b/mindspore/core/utils/parallel_node_check.cc @@ -37,7 +37,7 @@ bool IsInParallelBlackList(const PrimitivePtr &prim) { return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end()); } -bool IsParallelCareCNode(const CNodePtr &cnode) { +bool IsParallelConsiderCNode(const CNodePtr &cnode) { if (cnode == nullptr || cnode->size() == 0) { return false; } diff --git a/mindspore/core/utils/parallel_node_check.h b/mindspore/core/utils/parallel_node_check.h index 04eafd10c3..491827d15c 100644 --- a/mindspore/core/utils/parallel_node_check.h +++ b/mindspore/core/utils/parallel_node_check.h @@ -21,6 +21,6 @@ namespace mindspore { bool IsInParallelBlackList(const PrimitivePtr &); -bool IsParallelCareCNode(const CNodePtr &); +bool IsParallelConsiderCNode(const CNodePtr &); } // namespace mindspore #endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_