From: @xiaoda_zh Reviewed-by: @stsuteng,@zh_qh Signed-off-by: @stsutengtags/v1.1.0
| @@ -198,6 +198,16 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { | |||||
| } | } | ||||
| } | } | ||||
| void CostGraph::Init() { | |||||
| inputs_tensor_name_list_.clear(); | |||||
| tuple_getitem_list_.clear(); | |||||
| ops_.clear(); | |||||
| edges_.clear(); | |||||
| connected_compoents_.clear(); | |||||
| out_edges_.clear(); | |||||
| in_edges_.clear(); | |||||
| } | |||||
| void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { | void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { | ||||
| for (auto it = ops_.begin(); it != ops_.end();) { | for (auto it = ops_.begin(); it != ops_.end();) { | ||||
| if ((*it) == op) { | if ((*it) == op) { | ||||
| @@ -61,6 +61,7 @@ class CostGraph { | |||||
| costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; | costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; | ||||
| } | } | ||||
| ~CostGraph() = default; | ~CostGraph() = default; | ||||
| void Init(); | |||||
| void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } | void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } | ||||
| OperatorInfoPtr FindOperatorByIndex(size_t index) { | OperatorInfoPtr FindOperatorByIndex(size_t index) { | ||||
| if (index >= ops_.size()) { | if (index >= ops_.size()) { | ||||
| @@ -226,8 +226,11 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn | |||||
| } | } | ||||
| void InitCostGraph() { | void InitCostGraph() { | ||||
| entire_costgraph = std::make_shared<CostGraph>(); | |||||
| if (entire_costgraph == nullptr) { | |||||
| entire_costgraph = std::make_shared<CostGraph>(); | |||||
| } | |||||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | entire_costgraph->SetDeviceMemoryAndCostParameter(); | ||||
| entire_costgraph->Init(); | |||||
| } | } | ||||
| OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { | OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { | ||||
| @@ -974,6 +977,7 @@ void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, co | |||||
| } | } | ||||
| Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { | Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { | ||||
| InitCostGraph(); | |||||
| if (CostModelContext::GetInstance()->is_multi_subgraphs()) { | if (CostModelContext::GetInstance()->is_multi_subgraphs()) { | ||||
| if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { | if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { | ||||
| MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " | MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " | ||||
| @@ -92,7 +92,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||||
| new_node->set_inputs_value(old_node->inputs_value()); | new_node->set_inputs_value(old_node->inputs_value()); | ||||
| ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); | ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); | ||||
| new_node->set_scope(scope); | new_node->set_scope(scope); | ||||
| if (IsParallelCareCNode(old_node) && new_node->scope() == kDefaultScope) { | |||||
| if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) { | |||||
| new_node->set_fullname_with_scope(old_node->fullname_with_scope()); | new_node->set_fullname_with_scope(old_node->fullname_with_scope()); | ||||
| } | } | ||||
| new_node->set_kernel_info(old_node->kernel_info_ptr()); | new_node->set_kernel_info(old_node->kernel_info_ptr()); | ||||
| @@ -37,7 +37,7 @@ bool IsInParallelBlackList(const PrimitivePtr &prim) { | |||||
| return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end()); | return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end()); | ||||
| } | } | ||||
| bool IsParallelCareCNode(const CNodePtr &cnode) { | |||||
| bool IsParallelConsiderCNode(const CNodePtr &cnode) { | |||||
| if (cnode == nullptr || cnode->size() == 0) { | if (cnode == nullptr || cnode->size() == 0) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -21,6 +21,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| bool IsInParallelBlackList(const PrimitivePtr &); | bool IsInParallelBlackList(const PrimitivePtr &); | ||||
| bool IsParallelCareCNode(const CNodePtr &); | |||||
| bool IsParallelConsiderCNode(const CNodePtr &); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ | #endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ | ||||