diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc index 8d14e0f1c3..d93de5422b 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -145,7 +145,7 @@ TensorParam Complete2DInputs(const std::vector> &o std::shared_ptr ParseGraph(const std::vector> &ops, const std::vector> &input_tensor_names) { - std::shared_ptr graph(new Graph); + std::shared_ptr graph = std::make_shared(); if (ops.size() > SIZE_MAX / 2) { MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2; } @@ -178,7 +178,7 @@ size_t GetIndexInInputTensorNames(const std::vector> &i return index; } } - MS_LOG(INFO) << "Get index failed, using SIZE_MAX insted"; + MS_LOG(INFO) << "Get index failed, using SIZE_MAX instead"; return SIZE_MAX; } @@ -253,7 +253,7 @@ std::shared_ptr EliminateGraph(const std::shared_ptr &graph, index_list->at(j)--; } } - std::shared_ptr new_graph(new Graph); + std::shared_ptr new_graph = std::make_shared(); for (size_t i = 0; i < graph->nodes.size(); i++) { if (index_list->at(i) > SIZE_MAX / 2) { continue; diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc index c19d334c23..7502e84a1c 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc @@ -299,7 +299,13 @@ bool FindReshape(const CNodePtr &cnode, std::unordered_set *op_cach } // Find previous node of Reshape, then obtain its strategy_cost_ vector to get its layout vector. -bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index) { +bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index, + size_t curr_depth) { + if (curr_depth > MAX_RECURSIVE_DEPTH) { + MS_LOG(WARNING) << "When finding Reshape's previous node, exceeded the max recursive depth: " + << MAX_RECURSIVE_DEPTH; + return false; + } // if previous node is a parameter, handle it in the outsize. if (node->isa()) { return false; @@ -338,7 +344,7 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op if (prim->name() == DEPEND && index != 1) { continue; } - if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) { + if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index, ++curr_depth)) { continue; } return true; @@ -350,7 +356,12 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op // Find next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector. // if reshape's output connect to several primitive, return the first layout found -bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index) { +bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index, + size_t curr_depth) { + if (curr_depth > MAX_RECURSIVE_DEPTH) { + MS_LOG(WARNING) << "When finding Reshape's next node, exceeded the max recursive depth: " << MAX_RECURSIVE_DEPTH; + return false; + } MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode->func_graph()); FuncGraphManagerPtr manager = cnode->func_graph()->manager(); @@ -379,7 +390,7 @@ bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_o MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) << " " << (op_info != nullptr); - if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index)) { + if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index, ++curr_depth)) { return true; } } diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h index 105893deb3..04367d1062 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h +++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h @@ -46,9 +46,11 @@ bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name bool FindReshape(const CNodePtr &cnode, std::unordered_set *op_cache); -bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index); +bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index, + size_t curr_depth); -bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index); +bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index, + size_t curr_depth); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 6bfe40fa6a..3c4978442d 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -21,6 +21,7 @@ namespace mindspore { namespace parallel { +constexpr size_t MAX_RECURSIVE_DEPTH = 1000; constexpr size_t PRELU_INPUTS_SIZE = 2; constexpr size_t PRELU_OUTPUTS_SIZE = 1; constexpr size_t PRELU_SECOND_INPUT_SIZE = 1; diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index c6ab5f1984..f97229ca56 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -234,6 +234,48 @@ void InitCostGraph() { entire_costgraph->Init(); } +void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim, + const std::unordered_map &attrs, bool is_last_nodes, + StrategyMap *stra_map, const std::string &strategy_key_name) { + // In this case, the configured strategy should be extracted to help setting cost + StrategyPtr strategyPtr; + if (is_last_nodes) { + bool full_batch = ParallelContext::GetInstance()->full_batch(); + strategyPtr = GenerateBatchParallelStrategy(operator_info, prim); + if (full_batch) { + SetLastNodeStrategy(strategyPtr); + } + } else if (StrategyFound(attrs)) { + strategyPtr = parallel::ExtractStrategy(attrs); + } else { + strategyPtr = (*stra_map)[strategy_key_name]; + } + if (strategyPtr != nullptr) { + if (prim->name() == RESHAPE) { + MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; + } + // Set cost for this configured strategy + if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { + MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; + } else if (FULLY_USE_DEVICES) { + // If configured to fully use devices, then checking for the user-specified strategy + int64_t used_devices = operator_info->used_devices(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size(); + // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel + if (used_devices == 1) { + return; + } + // 'used_devices == -1' means that 'used_devices_' is not set + if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) { + MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, " + << "but the specified strategy uses device: " << used_devices + << ", total devices: " << total_device_num; + } + } + } +} + OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes, StrategyMap *stra_map) { MS_EXCEPTION_IF_NULL(prim); @@ -290,9 +332,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & bool load_strategy_from_ckpt = StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); // If no strategy has been configured for this operator, then candidate strategies are generated for - // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy. - // if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint . - if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt && !is_last_nodes) { + // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy; + // if strategy is set to load from checkpoint, it is preferred to load strategy from checkpoint. + bool is_gen_stra = (!StrategyFound(attrs) || prim->name() == CAST) && (!load_strategy_from_ckpt) && (!is_last_nodes); + if (is_gen_stra) { // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for // BatchParallelInfo operator operator_info->ComputeBatchSplitFlagList(); @@ -307,43 +350,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name(); } } else { - // In this case, the configured strategy should be extracted to help setting cost - StrategyPtr strategyPtr; - if (is_last_nodes) { - bool full_batch = ParallelContext::GetInstance()->full_batch(); - strategyPtr = GenerateBatchParallelStrategy(operator_info, prim); - if (full_batch) { - SetLastNodeStrategy(strategyPtr); - } - } else if (StrategyFound(attrs)) { - strategyPtr = parallel::ExtractStrategy(attrs); - } else { - strategyPtr = (*stra_map)[strategy_key_name]; - } - if (strategyPtr != nullptr) { - if (prim->name() == RESHAPE) { - MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; - } - // Set cost for this configured strategy - if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { - MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; - } else if (FULLY_USE_DEVICES) { - // If configured to fully use devices, then checking for the user-specified strategy - int64_t used_devices = operator_info->used_devices(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size(); - // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel - if (used_devices == 1) { - return operator_info; - } - // 'used_devices == -1' means that 'used_devices_' is not set - if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) { - MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, " - << "but the specified strategy uses device: " << used_devices - << ", total devices: " << total_device_num; - } - } - } + SetStrategyToOperator(operator_info, prim, attrs, is_last_nodes, stra_map, strategy_key_name); } return operator_info; } @@ -451,6 +458,29 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node return SUCCESS; } +void SetOperatorToCNode(const OperatorInfoPtr ¤t_op_ptr, const PrimitivePtr &prim, const CNodePtr &cnode) { + if (current_op_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; + } else { + bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && + (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && + (current_op_ptr->name().find(prim->name()) == std::string::npos); + if (is_find_wrong) { + MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() + << " does not match the Prim: " << prim->name(); + } + + // Needed by rec_parser + ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId()); + + cnode->set_user_data(current_op_ptr); + MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() + << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() + << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() + << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); + } +} + // Using CNode's UniqueIdThroughCopys to construct nodes Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &root) { MS_LOG(INFO) << "Constructing nodes for cost graph begins."; @@ -497,7 +527,8 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no if (search_cnode == from_cnode_to_info.end()) { size_t loop_index = 0; bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); - if (DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) { + bool is_op_created = DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size()); + if (is_op_created) { const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && @@ -543,27 +574,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no // Needed by rec_parser entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); } else { - auto current_op_ptr = search_cnode->second; - if (current_op_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; - } else { - bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && - (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && - (current_op_ptr->name().find(prim->name()) == std::string::npos); - if (is_find_wrong) { - MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() - << " does not match the Prim: " << prim->name(); - } - - // Needed by rec_parser - ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId()); - - cnode->set_user_data(current_op_ptr); - MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() - << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() - << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() - << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); - } + SetOperatorToCNode(search_cnode->second, prim, cnode); } } @@ -571,6 +582,46 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no return SUCCESS; } +void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const OperatorInfoPtr &node_op_info, + const CNodePtr &cnode, const CNodePtr &prev_cnode, const PrimitivePtr &prim, + const PrimitivePtr &prev_prim, size_t output_index, size_t input_index, + size_t *edge_count) { + std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); + // If the edge between these two operators already has been added, then the edge will not be added again. + if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, input_index - 1)) { + return; + } + EdgePtr edge_ptr; + MS_LOG(INFO) << "Creating edge: " << edge_name; + if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) { + MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope() + << ", cnode_fullname: " << cnode->fullname_with_scope(); + MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge."; + return; + } + bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || + (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); + if (follow_strategy) { + // Redistribution in not allowed on the edge. + // Elementwise operators have the same strategy as their previous operators. + edge_ptr = + std::make_shared(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false, true); + } else { + edge_ptr = std::make_shared(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false); + } + + // Init costs for this edge + if (edge_ptr->InitEdgeCost() != SUCCESS) { + MS_LOG(EXCEPTION) << "Edge cost initialization failed"; + } + node_op_info->AddPrevEdge(edge_ptr); + prev_op_info->AddSuccEdge(edge_ptr); + entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr); + MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " << node_op_info->name(); + (*edge_count)++; + return; +} + void ConstructCostGraphEdges(const std::vector &all_nodes) { // Step 2 MS_LOG(INFO) << "Constructing edges for cost graph begins."; @@ -600,45 +651,12 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast(); size_t output_index = 0; - bool bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) || - (prev_prim->name() == DEPEND); - while (bool_result) { + while ((IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) || + (prev_prim->name() == DEPEND)) { if (IsAutoParallelCareNode(prev_cnode)) { auto prev_op_info = prev_cnode->user_data(); - std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); - // If the edge between these two operators already has been added, then the edge will not be added again. - if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { - break; - } - EdgePtr edge_ptr; - MS_LOG(INFO) << "Creating edge: " << edge_name; - if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) { - MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope() - << ", cnode_fullname: " << cnode->fullname_with_scope(); - MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge."; - break; - } - bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || - (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); - if (follow_strategy) { - // Redistribution in not allowed on the edge. - // Elementwise operators have the same strategy as their previous operators. - edge_ptr = std::make_shared(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true); - } else { - edge_ptr = std::make_shared(edge_name, prev_op_info, node_op_info, output_index, i - 1, false); - } - - // Init costs for this edge - if (edge_ptr->InitEdgeCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Edge cost initialization failed"; - } - node_op_info->AddPrevEdge(edge_ptr); - prev_op_info->AddSuccEdge(edge_ptr); - entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr); - MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " - << node_op_info->name(); - edge_count++; - + CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i, + &edge_count); break; } else if (prev_prim->name() == prim::kTupleGetItem) { // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before @@ -673,8 +691,6 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { << "and creating an edge between the Operator before " << "'depend' and the Operator after 'depend'."; } - bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) || - (prev_prim->name() == DEPEND); } } MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); @@ -832,7 +848,7 @@ void ReshapeCostCompute(const std::vector &all_nodes) { pre_operator_info = reshape_info; pre_stra_costs = reshape_info->strategy_cost(); } else { - if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) { + if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index, 0)) { MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed"; } pre_stra_costs = pre_operator_info->strategy_cost(); @@ -841,7 +857,7 @@ void ReshapeCostCompute(const std::vector &all_nodes) { int64_t in_index = 0; OperatorInfoPtr next_operator_info; std::vector> next_stra_costs; - bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index); + bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index, 0); if (!find_next_node) { MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed"; } @@ -858,7 +874,7 @@ void ReshapeCostCompute(const std::vector &all_nodes) { bool is_prev_param = pre_node->isa(); if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) != SUCCESS) { - MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!"; + MS_LOG(EXCEPTION) << "reshape generate strategy_costs failed!"; } } }