|
|
|
@@ -226,8 +226,11 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn |
|
|
|
} |
|
|
|
|
|
|
|
void InitCostGraph() { |
|
|
|
entire_costgraph = std::make_shared<CostGraph>(); |
|
|
|
if (entire_costgraph == nullptr) { |
|
|
|
entire_costgraph = std::make_shared<CostGraph>(); |
|
|
|
} |
|
|
|
entire_costgraph->SetDeviceMemoryAndCostParameter(); |
|
|
|
entire_costgraph->Init(); |
|
|
|
} |
|
|
|
|
|
|
|
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { |
|
|
|
@@ -974,6 +977,7 @@ void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, co |
|
|
|
} |
|
|
|
|
|
|
|
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { |
|
|
|
InitCostGraph(); |
|
|
|
if (CostModelContext::GetInstance()->is_multi_subgraphs()) { |
|
|
|
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { |
|
|
|
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " |
|
|
|
|