|
|
|
@@ -233,7 +233,8 @@ void InitCostGraph() { |
|
|
|
entire_costgraph->Init(); |
|
|
|
} |
|
|
|
|
|
|
|
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { |
|
|
|
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes, |
|
|
|
StrategyMap *stra_map) { |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto attrs = prim->attrs(); |
|
|
|
@@ -290,7 +291,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & |
|
|
|
// If no strategy has been configured for this operator, then candidate strategies are generated for |
|
|
|
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy. |
|
|
|
// If the strategy is set to be loaded from a checkpoint, loading it from the checkpoint is preferred. |
|
|
|
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) { |
|
|
|
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt && !is_last_nodes) { |
|
|
|
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for |
|
|
|
// BatchParallelInfo operator |
|
|
|
operator_info->ComputeBatchSplitFlagList(); |
|
|
|
@@ -307,10 +308,16 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & |
|
|
|
} else { |
|
|
|
// In this case, the configured strategy should be extracted to help setting cost |
|
|
|
StrategyPtr strategyPtr; |
|
|
|
if (load_strategy_from_ckpt) { |
|
|
|
strategyPtr = (*stra_map)[strategy_key_name]; |
|
|
|
} else { |
|
|
|
if (is_last_nodes) { |
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch(); |
|
|
|
strategyPtr = GenerateBatchParallelStrategy(operator_info, prim); |
|
|
|
if (full_batch) { |
|
|
|
SetLastNodeStrategy(strategyPtr); |
|
|
|
} |
|
|
|
} else if (StrategyFound(attrs)) { |
|
|
|
strategyPtr = parallel::ExtractStrategy(attrs); |
|
|
|
} else { |
|
|
|
strategyPtr = (*stra_map)[strategy_key_name]; |
|
|
|
} |
|
|
|
if (strategyPtr != nullptr) { |
|
|
|
if (prim->name() == RESHAPE) { |
|
|
|
@@ -341,8 +348,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & |
|
|
|
} |
|
|
|
|
|
|
|
// Using CNode's UniqueIds to construct nodes |
|
|
|
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) { |
|
|
|
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { |
|
|
|
MS_LOG(INFO) << "Constructing nodes for cost graph begins."; |
|
|
|
entire_costgraph = std::make_shared<CostGraph>(); |
|
|
|
entire_costgraph->SetDeviceMemoryAndCostParameter(); |
|
|
|
// The map from CNode's UniqueId to its operatorInfo |
|
|
|
std::map<std::string, OperatorInfoPtr> from_cnode_to_info; |
|
|
|
// The operator_infos in a loop |
|
|
|
@@ -356,7 +365,12 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node |
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::string> last_forward_node_ids; |
|
|
|
if (!root->has_flag(TRAINING)) { |
|
|
|
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids); |
|
|
|
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; |
|
|
|
} |
|
|
|
// Step 1 |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
// NOTE: we only care about splittable Primitive operators |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
@@ -401,7 +415,9 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node |
|
|
|
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr)); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); |
|
|
|
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != |
|
|
|
last_forward_node_ids.end(); |
|
|
|
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map); |
|
|
|
if (operator_info == nullptr) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -436,8 +452,10 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node |
|
|
|
} |
|
|
|
|
|
|
|
// Using CNode's UniqueIdThroughCopys to construct nodes |
|
|
|
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) { |
|
|
|
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { |
|
|
|
MS_LOG(INFO) << "Constructing nodes for cost graph begins."; |
|
|
|
entire_costgraph = std::make_shared<CostGraph>(); |
|
|
|
entire_costgraph->SetDeviceMemoryAndCostParameter(); |
|
|
|
// The map from CNode's UniqueIdThroughCopy to its operatorInfo |
|
|
|
std::map<std::string, OperatorInfoPtr> from_cnode_to_info; |
|
|
|
// The operator_infos in a loop |
|
|
|
@@ -451,6 +469,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no |
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; |
|
|
|
} |
|
|
|
} |
|
|
|
std::vector<std::string> last_forward_node_ids; |
|
|
|
if (!root->has_flag(TRAINING)) { |
|
|
|
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids); |
|
|
|
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; |
|
|
|
} |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
// NOTE: we only care about splittable Primitive operators |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
@@ -496,7 +519,9 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no |
|
|
|
continue; |
|
|
|
} |
|
|
|
// In this case, the corresponding OperatorInfo is not created, create the new one. |
|
|
|
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); |
|
|
|
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != |
|
|
|
last_forward_node_ids.end(); |
|
|
|
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map); |
|
|
|
if (operator_info == nullptr) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|