diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 9cd5e35e45..822daac1d0 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -233,7 +233,8 @@ void InitCostGraph() { entire_costgraph->Init(); } -OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { +OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes, + StrategyMap *stra_map) { MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(cnode); auto attrs = prim->attrs(); @@ -290,7 +291,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & // If no strategy has been configured for this operator, then candidate strategies are generated for // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy. // if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint . - if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) { + if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt && !is_last_nodes) { // Compute split_flag_list_, indicating which input has batch dimension. 
This is ONLY used for preparation for // BatchParallelInfo operator operator_info->ComputeBatchSplitFlagList(); @@ -307,10 +308,16 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & } else { // In this case, the configured strategy should be extracted to help setting cost StrategyPtr strategyPtr; - if (load_strategy_from_ckpt) { - strategyPtr = (*stra_map)[strategy_key_name]; - } else { + if (is_last_nodes) { + bool full_batch = ParallelContext::GetInstance()->full_batch(); + strategyPtr = GenerateBatchParallelStrategy(operator_info, prim); + if (full_batch) { + SetLastNodeStrategy(strategyPtr); + } + } else if (StrategyFound(attrs)) { strategyPtr = parallel::ExtractStrategy(attrs); + } else { + strategyPtr = (*stra_map)[strategy_key_name]; } if (strategyPtr != nullptr) { if (prim->name() == RESHAPE) { @@ -341,8 +348,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & } // Using CNode's UniqueIds to construct nodes -Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &) { +Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &root) { MS_LOG(INFO) << "Constructing nodes for cost graph begins."; + entire_costgraph = std::make_shared(); + entire_costgraph->SetDeviceMemoryAndCostParameter(); // The map from CNode's UniqueId to its operatorInfo std::map from_cnode_to_info; // The operator_infos in a loop @@ -356,7 +365,12 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; } } - + std::vector last_forward_node_ids; + if (!root->has_flag(TRAINING)) { + FindLastNodesUniqueId(all_nodes, &last_forward_node_ids); + MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; + } + // Step 1 for (auto &node : all_nodes) { // NOTE: we only care about splittable Primitive operators auto cnode = node->cast(); @@ -401,7 +415,9 @@ 
Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr)); continue; } - auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); + bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != + last_forward_node_ids.end(); + auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map); if (operator_info == nullptr) { return FAILED; } @@ -436,8 +452,10 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node } // Using CNode's UniqueIdThroughCopys to construct nodes -Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &) { +Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &root) { MS_LOG(INFO) << "Constructing nodes for cost graph begins."; + entire_costgraph = std::make_shared(); + entire_costgraph->SetDeviceMemoryAndCostParameter(); // The map from CNode's UniqueIdThroughCopy to its operatorInfo std::map from_cnode_to_info; // The operator_infos in a loop @@ -451,6 +469,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; } } + std::vector last_forward_node_ids; + if (!root->has_flag(TRAINING)) { + FindLastNodesUniqueId(all_nodes, &last_forward_node_ids); + MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; + } for (auto &node : all_nodes) { // NOTE: we only care about splittable Primitive operators auto cnode = node->cast(); @@ -496,7 +519,9 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no continue; } // In this case, the corresponding OperatorInfo is not created, create the new one. 
- auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); + bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != + last_forward_node_ids.end(); + auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map); if (operator_info == nullptr) { return FAILED; } diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index f823a302c2..d13058b698 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1638,7 +1638,7 @@ bool FindPreNodes(const AnfNodePtr &node, vector *unique_ids) { return find; } -void FindLastNodesUniqueId(const std::vector &all_nodes, vector *unique_ids) { +void FindLastNodesUniqueId(const std::vector &all_nodes, std::vector *unique_ids) { MS_EXCEPTION_IF_NULL(unique_ids); for (auto &node : all_nodes) { auto cnode = node->cast(); @@ -1754,10 +1754,10 @@ void ExtractInformation(const std::vector &all_nodes, bool is_traini MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() << " is empty, using batch parallel"; strategyPtr = GenerateBatchParallelStrategy(operator_, prim); - } else if (load_strategy_from_ckpt) { - strategyPtr = stra_map[strategy_key_name]; - } else { + } else if (StrategyFound(attrs)) { strategyPtr = ExtractStrategy(attrs); + } else { + strategyPtr = stra_map[strategy_key_name]; } if (strategyPtr != nullptr) { if (is_last_nodes && full_batch) { diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index a9c6a436d7..0aa875c5ce 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -165,6 +165,10 @@ bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter); void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const 
OperatorInfoPtr &distribute_operator, const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index); + +void SetLastNodeStrategy(const StrategyPtr strategyPtr); + +void FindLastNodesUniqueId(const std::vector &all_nodes, std::vector *unique_ids); } // namespace parallel } // namespace mindspore diff --git a/tests/ut/python/parallel/test_eval.py b/tests/ut/python/parallel/test_eval.py index eb777c4d8c..588fc41be3 100644 --- a/tests/ut/python/parallel/test_eval.py +++ b/tests/ut/python/parallel/test_eval.py @@ -67,3 +67,20 @@ def test_train_and_eval(): _executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True) context.reset_auto_parallel_context() + +def test_train_and_eval_auto(): + context.set_context(save_graphs=True, mode=0) + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16) + strategy1 = ((4, 4), (4, 4)) + strategy2 = ((4, 4),) + net = Net(_w1, strategy1, strategy2) + eval_net = EvalNet(net, strategy2=strategy2) + net.set_auto_parallel() + net.set_train() + _executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True) + + eval_net.set_train(mode=False) + eval_net.set_auto_parallel() + _executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True) + + context.reset_auto_parallel_context()