|
|
|
@@ -234,6 +234,48 @@ void InitCostGraph() { |
|
|
|
entire_costgraph->Init(); |
|
|
|
} |
|
|
|
|
|
|
|
void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim, |
|
|
|
const std::unordered_map<std::string, ValuePtr> &attrs, bool is_last_nodes, |
|
|
|
StrategyMap *stra_map, const std::string &strategy_key_name) { |
|
|
|
// In this case, the configured strategy should be extracted to help setting cost |
|
|
|
StrategyPtr strategyPtr; |
|
|
|
if (is_last_nodes) { |
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch(); |
|
|
|
strategyPtr = GenerateBatchParallelStrategy(operator_info, prim); |
|
|
|
if (full_batch) { |
|
|
|
SetLastNodeStrategy(strategyPtr); |
|
|
|
} |
|
|
|
} else if (StrategyFound(attrs)) { |
|
|
|
strategyPtr = parallel::ExtractStrategy(attrs); |
|
|
|
} else { |
|
|
|
strategyPtr = (*stra_map)[strategy_key_name]; |
|
|
|
} |
|
|
|
if (strategyPtr != nullptr) { |
|
|
|
if (prim->name() == RESHAPE) { |
|
|
|
MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; |
|
|
|
} |
|
|
|
// Set cost for this configured strategy |
|
|
|
if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; |
|
|
|
} else if (FULLY_USE_DEVICES) { |
|
|
|
// If configured to fully use devices, then checking for the user-specified strategy |
|
|
|
int64_t used_devices = operator_info->used_devices(); |
|
|
|
MS_EXCEPTION_IF_NULL(g_device_manager); |
|
|
|
auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size(); |
|
|
|
// 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel |
|
|
|
if (used_devices == 1) { |
|
|
|
return; |
|
|
|
} |
|
|
|
// 'used_devices == -1' means that 'used_devices_' is not set |
|
|
|
if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) { |
|
|
|
MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, " |
|
|
|
<< "but the specified strategy uses device: " << used_devices |
|
|
|
<< ", total devices: " << total_device_num; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes, |
|
|
|
StrategyMap *stra_map) { |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
@@ -290,9 +332,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & |
|
|
|
bool load_strategy_from_ckpt = |
|
|
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); |
|
|
|
// If no strategy has been configured for this operator, then candidate strategies are generated for |
|
|
|
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy. |
|
|
|
// if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint . |
|
|
|
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt && !is_last_nodes) { |
|
|
|
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy; |
|
|
|
// if strategy is set to load from checkpoint, it is preferred to load strategy from checkpoint. |
|
|
|
bool is_gen_stra = (!StrategyFound(attrs) || prim->name() == CAST) && (!load_strategy_from_ckpt) && (!is_last_nodes); |
|
|
|
if (is_gen_stra) { |
|
|
|
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for |
|
|
|
// BatchParallelInfo operator |
|
|
|
operator_info->ComputeBatchSplitFlagList(); |
|
|
|
@@ -307,43 +350,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & |
|
|
|
MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name(); |
|
|
|
} |
|
|
|
} else { |
|
|
|
// In this case, the configured strategy should be extracted to help setting cost |
|
|
|
StrategyPtr strategyPtr; |
|
|
|
if (is_last_nodes) { |
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch(); |
|
|
|
strategyPtr = GenerateBatchParallelStrategy(operator_info, prim); |
|
|
|
if (full_batch) { |
|
|
|
SetLastNodeStrategy(strategyPtr); |
|
|
|
} |
|
|
|
} else if (StrategyFound(attrs)) { |
|
|
|
strategyPtr = parallel::ExtractStrategy(attrs); |
|
|
|
} else { |
|
|
|
strategyPtr = (*stra_map)[strategy_key_name]; |
|
|
|
} |
|
|
|
if (strategyPtr != nullptr) { |
|
|
|
if (prim->name() == RESHAPE) { |
|
|
|
MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; |
|
|
|
} |
|
|
|
// Set cost for this configured strategy |
|
|
|
if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; |
|
|
|
} else if (FULLY_USE_DEVICES) { |
|
|
|
// If configured to fully use devices, then checking for the user-specified strategy |
|
|
|
int64_t used_devices = operator_info->used_devices(); |
|
|
|
MS_EXCEPTION_IF_NULL(g_device_manager); |
|
|
|
auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size(); |
|
|
|
// 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel |
|
|
|
if (used_devices == 1) { |
|
|
|
return operator_info; |
|
|
|
} |
|
|
|
// 'used_devices == -1' means that 'used_devices_' is not set |
|
|
|
if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) { |
|
|
|
MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, " |
|
|
|
<< "but the specified strategy uses device: " << used_devices |
|
|
|
<< ", total devices: " << total_device_num; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
SetStrategyToOperator(operator_info, prim, attrs, is_last_nodes, stra_map, strategy_key_name); |
|
|
|
} |
|
|
|
return operator_info; |
|
|
|
} |
|
|
|
@@ -451,6 +458,29 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
void SetOperatorToCNode(const OperatorInfoPtr ¤t_op_ptr, const PrimitivePtr &prim, const CNodePtr &cnode) { |
|
|
|
if (current_op_ptr == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; |
|
|
|
} else { |
|
|
|
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(prim->name()) == std::string::npos); |
|
|
|
if (is_find_wrong) { |
|
|
|
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() |
|
|
|
<< " does not match the Prim: " << prim->name(); |
|
|
|
} |
|
|
|
|
|
|
|
// Needed by rec_parser |
|
|
|
ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId()); |
|
|
|
|
|
|
|
cnode->set_user_data<OperatorInfo>(current_op_ptr); |
|
|
|
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() |
|
|
|
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() |
|
|
|
<< ", CNode fullname_with_scope: " << cnode->fullname_with_scope() |
|
|
|
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Using CNode's UniqueIdThroughCopys to construct nodes |
|
|
|
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { |
|
|
|
MS_LOG(INFO) << "Constructing nodes for cost graph begins."; |
|
|
|
@@ -497,7 +527,8 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no |
|
|
|
if (search_cnode == from_cnode_to_info.end()) { |
|
|
|
size_t loop_index = 0; |
|
|
|
bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); |
|
|
|
if (DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) { |
|
|
|
bool is_op_created = DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size()); |
|
|
|
if (is_op_created) { |
|
|
|
const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; |
|
|
|
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && |
|
|
|
@@ -543,27 +574,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no |
|
|
|
// Needed by rec_parser |
|
|
|
entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); |
|
|
|
} else { |
|
|
|
auto current_op_ptr = search_cnode->second; |
|
|
|
if (current_op_ptr == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; |
|
|
|
} else { |
|
|
|
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(prim->name()) == std::string::npos); |
|
|
|
if (is_find_wrong) { |
|
|
|
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() |
|
|
|
<< " does not match the Prim: " << prim->name(); |
|
|
|
} |
|
|
|
|
|
|
|
// Needed by rec_parser |
|
|
|
ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId()); |
|
|
|
|
|
|
|
cnode->set_user_data<OperatorInfo>(current_op_ptr); |
|
|
|
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() |
|
|
|
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() |
|
|
|
<< ", CNode fullname_with_scope: " << cnode->fullname_with_scope() |
|
|
|
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); |
|
|
|
} |
|
|
|
SetOperatorToCNode(search_cnode->second, prim, cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -571,6 +582,46 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const OperatorInfoPtr &node_op_info, |
|
|
|
const CNodePtr &cnode, const CNodePtr &prev_cnode, const PrimitivePtr &prim, |
|
|
|
const PrimitivePtr &prev_prim, size_t output_index, size_t input_index, |
|
|
|
size_t *edge_count) { |
|
|
|
std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); |
|
|
|
// If the edge between these two operators already has been added, then the edge will not be added again. |
|
|
|
if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, input_index - 1)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
EdgePtr edge_ptr; |
|
|
|
MS_LOG(INFO) << "Creating edge: " << edge_name; |
|
|
|
if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) { |
|
|
|
MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope() |
|
|
|
<< ", cnode_fullname: " << cnode->fullname_with_scope(); |
|
|
|
MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge."; |
|
|
|
return; |
|
|
|
} |
|
|
|
bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || |
|
|
|
(ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); |
|
|
|
if (follow_strategy) { |
|
|
|
// Redistribution in not allowed on the edge. |
|
|
|
// Elementwise operators have the same strategy as their previous operators. |
|
|
|
edge_ptr = |
|
|
|
std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false, true); |
|
|
|
} else { |
|
|
|
edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false); |
|
|
|
} |
|
|
|
|
|
|
|
// Init costs for this edge |
|
|
|
if (edge_ptr->InitEdgeCost() != SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "Edge cost initialization failed"; |
|
|
|
} |
|
|
|
node_op_info->AddPrevEdge(edge_ptr); |
|
|
|
prev_op_info->AddSuccEdge(edge_ptr); |
|
|
|
entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr); |
|
|
|
MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " << node_op_info->name(); |
|
|
|
(*edge_count)++; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
// Step 2 |
|
|
|
MS_LOG(INFO) << "Constructing edges for cost graph begins."; |
|
|
|
@@ -600,45 +651,12 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
size_t output_index = 0; |
|
|
|
|
|
|
|
bool bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) || |
|
|
|
(prev_prim->name() == DEPEND); |
|
|
|
while (bool_result) { |
|
|
|
while ((IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) || |
|
|
|
(prev_prim->name() == DEPEND)) { |
|
|
|
if (IsAutoParallelCareNode(prev_cnode)) { |
|
|
|
auto prev_op_info = prev_cnode->user_data<OperatorInfo>(); |
|
|
|
std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); |
|
|
|
// If the edge between these two operators already has been added, then the edge will not be added again. |
|
|
|
if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { |
|
|
|
break; |
|
|
|
} |
|
|
|
EdgePtr edge_ptr; |
|
|
|
MS_LOG(INFO) << "Creating edge: " << edge_name; |
|
|
|
if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) { |
|
|
|
MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope() |
|
|
|
<< ", cnode_fullname: " << cnode->fullname_with_scope(); |
|
|
|
MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge."; |
|
|
|
break; |
|
|
|
} |
|
|
|
bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || |
|
|
|
(ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); |
|
|
|
if (follow_strategy) { |
|
|
|
// Redistribution in not allowed on the edge. |
|
|
|
// Elementwise operators have the same strategy as their previous operators. |
|
|
|
edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true); |
|
|
|
} else { |
|
|
|
edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false); |
|
|
|
} |
|
|
|
|
|
|
|
// Init costs for this edge |
|
|
|
if (edge_ptr->InitEdgeCost() != SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "Edge cost initialization failed"; |
|
|
|
} |
|
|
|
node_op_info->AddPrevEdge(edge_ptr); |
|
|
|
prev_op_info->AddSuccEdge(edge_ptr); |
|
|
|
entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr); |
|
|
|
MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " |
|
|
|
<< node_op_info->name(); |
|
|
|
edge_count++; |
|
|
|
|
|
|
|
CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i, |
|
|
|
&edge_count); |
|
|
|
break; |
|
|
|
} else if (prev_prim->name() == prim::kTupleGetItem) { |
|
|
|
// In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before |
|
|
|
@@ -673,8 +691,6 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
<< "and creating an edge between the Operator before " |
|
|
|
<< "'depend' and the Operator after 'depend'."; |
|
|
|
} |
|
|
|
bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) || |
|
|
|
(prev_prim->name() == DEPEND); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); |
|
|
|
@@ -832,7 +848,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
pre_operator_info = reshape_info; |
|
|
|
pre_stra_costs = reshape_info->strategy_cost(); |
|
|
|
} else { |
|
|
|
if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) { |
|
|
|
if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index, 0)) { |
|
|
|
MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed"; |
|
|
|
} |
|
|
|
pre_stra_costs = pre_operator_info->strategy_cost(); |
|
|
|
@@ -841,7 +857,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
int64_t in_index = 0; |
|
|
|
OperatorInfoPtr next_operator_info; |
|
|
|
std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs; |
|
|
|
bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index); |
|
|
|
bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index, 0); |
|
|
|
if (!find_next_node) { |
|
|
|
MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed"; |
|
|
|
} |
|
|
|
@@ -858,7 +874,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
bool is_prev_param = pre_node->isa<Parameter>(); |
|
|
|
if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) != |
|
|
|
SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!"; |
|
|
|
MS_LOG(EXCEPTION) << "reshape generate strategy_costs failed!"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|