Browse Source

!9705 [Auto parallel] Add 'InitCostGraph' in rec_algo

From: @xiaoda_zh
Reviewed-by: @stsuteng,@zh_qh
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
8ddb10fd8a
6 changed files with 19 additions and 4 deletions
  1. +10
    -0
      mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
  2. +1
    -0
      mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
  3. +5
    -1
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
  4. +1
    -1
      mindspore/core/ir/func_graph_cloner.cc
  5. +1
    -1
      mindspore/core/utils/parallel_node_check.cc
  6. +1
    -1
      mindspore/core/utils/parallel_node_check.h

+ 10
- 0
mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc View File

@@ -198,6 +198,16 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
} }
} }


void CostGraph::Init() {
inputs_tensor_name_list_.clear();
tuple_getitem_list_.clear();
ops_.clear();
edges_.clear();
connected_compoents_.clear();
out_edges_.clear();
in_edges_.clear();
}

void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
for (auto it = ops_.begin(); it != ops_.end();) { for (auto it = ops_.begin(); it != ops_.end();) {
if ((*it) == op) { if ((*it) == op) {


+ 1
- 0
mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h View File

@@ -61,6 +61,7 @@ class CostGraph {
costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND;
} }
~CostGraph() = default; ~CostGraph() = default;
void Init();
void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); }
OperatorInfoPtr FindOperatorByIndex(size_t index) { OperatorInfoPtr FindOperatorByIndex(size_t index) {
if (index >= ops_.size()) { if (index >= ops_.size()) {


+ 5
- 1
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc View File

@@ -226,8 +226,11 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn
} }


void InitCostGraph() { void InitCostGraph() {
entire_costgraph = std::make_shared<CostGraph>();
if (entire_costgraph == nullptr) {
entire_costgraph = std::make_shared<CostGraph>();
}
entire_costgraph->SetDeviceMemoryAndCostParameter(); entire_costgraph->SetDeviceMemoryAndCostParameter();
entire_costgraph->Init();
} }


OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
@@ -974,6 +977,7 @@ void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, co
} }


Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
InitCostGraph();
if (CostModelContext::GetInstance()->is_multi_subgraphs()) { if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "


+ 1
- 1
mindspore/core/ir/func_graph_cloner.cc View File

@@ -92,7 +92,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
new_node->set_inputs_value(old_node->inputs_value()); new_node->set_inputs_value(old_node->inputs_value());
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope); new_node->set_scope(scope);
if (IsParallelCareCNode(old_node) && new_node->scope() == kDefaultScope) {
if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) {
new_node->set_fullname_with_scope(old_node->fullname_with_scope()); new_node->set_fullname_with_scope(old_node->fullname_with_scope());
} }
new_node->set_kernel_info(old_node->kernel_info_ptr()); new_node->set_kernel_info(old_node->kernel_info_ptr());


+ 1
- 1
mindspore/core/utils/parallel_node_check.cc View File

@@ -37,7 +37,7 @@ bool IsInParallelBlackList(const PrimitivePtr &prim) {
return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end()); return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end());
} }


bool IsParallelCareCNode(const CNodePtr &cnode) {
bool IsParallelConsiderCNode(const CNodePtr &cnode) {
if (cnode == nullptr || cnode->size() == 0) { if (cnode == nullptr || cnode->size() == 0) {
return false; return false;
} }


+ 1
- 1
mindspore/core/utils/parallel_node_check.h View File

@@ -21,6 +21,6 @@


namespace mindspore { namespace mindspore {
bool IsInParallelBlackList(const PrimitivePtr &); bool IsInParallelBlackList(const PrimitivePtr &);
bool IsParallelCareCNode(const CNodePtr &);
bool IsParallelConsiderCNode(const CNodePtr &);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ #endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_

Loading…
Cancel
Save