|
|
|
@@ -1512,7 +1512,87 @@ Status ValidStageCheck(const std::vector<int32_t> &stages, int32_t strategy_stag |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
// find previous parallel care node. |
|
|
|
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) { |
|
|
|
MS_EXCEPTION_IF_NULL(unique_ids); |
|
|
|
// if previous node is a parameter, handle it in the outsize. |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
CNodePtr cnode = node->cast<CNodePtr>(); |
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) { |
|
|
|
unique_ids->push_back(cnode->UniqueId()); |
|
|
|
return true; |
|
|
|
} |
|
|
|
bool find = false; |
|
|
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) { |
|
|
|
if (prim->name() == DEPEND && index != 1) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (FindPreNodes(cnode->inputs()[index], unique_ids)) { |
|
|
|
find = true; |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
return find; |
|
|
|
} |
|
|
|
|
|
|
|
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, vector<std::string> *unique_ids) { |
|
|
|
MS_EXCEPTION_IF_NULL(unique_ids); |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); |
|
|
|
if (prim->name() == RETURN) { |
|
|
|
if (!FindPreNodes(cnode, unique_ids)) { |
|
|
|
MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Build a strategy from the operator's batch strategies and record it on the primitive.
//
// operator_:  operator info used to generate the batch strategies; must not be null.
// prim:       primitive whose GEN_STRATEGY attribute is set to the generated strategy
//             (as a ValueTuple) so that it can be displayed; must not be null.
// Returns the strategy created by NewStrategy(0, ...) from the generated batch strategies.
// Raises (via MS_EXCEPTION_IF_NULL) if GenerateBatchStrategies() yields a null pointer.
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim) {
  MS_EXCEPTION_IF_NULL(operator_);
  MS_EXCEPTION_IF_NULL(prim);

  std::shared_ptr<Strategys> batch_strategies = operator_->GenerateBatchStrategies();
  MS_EXCEPTION_IF_NULL(batch_strategies);
  StrategyPtr result = NewStrategy(0, *batch_strategies);

  // Convert each per-input strategy into a Value so the whole set can be stored as an attribute.
  std::vector<ValuePtr> strategy_values;
  for (auto &input_strategy : *batch_strategies) {
    strategy_values.push_back(MakeValue(input_strategy));
  }
  ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(strategy_values);

  // display the strategy generated by batch parallel
  auto prim_attrs = prim->attrs();
  prim_attrs[GEN_STRATEGY] = strategy_tuple;
  (void)prim->SetAttrs(prim_attrs);
  MS_LOG(INFO) << "prim " << prim->name() << " batch parallel strategy is " << prim_attrs[GEN_STRATEGY]->ToString();

  return result;
}
|
|
|
|
|
|
|
// Overwrite every entry of the strategy's input dimensions with 1 and write the result back
// through ResetInputs().
//
// strategyPtr: strategy to modify in place; it is dereferenced without a check, so the caller
//              must pass a non-null pointer.
void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
  auto input_dims = strategyPtr->GetInputDim();
  for (auto &dims : input_dims) {
    for (auto &split : dims) {
      split = 1;
    }
  }
  strategyPtr->ResetInputs(input_dims);
}
|
|
|
|
|
|
|
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) { |
|
|
|
// load strategy map from checkpoint |
|
|
|
StrategyMap stra_map; |
|
|
|
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { |
|
|
|
@@ -1520,7 +1600,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
vector<std::string> last_forward_node_ids; |
|
|
|
if (!is_training) { |
|
|
|
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids); |
|
|
|
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; |
|
|
|
} |
|
|
|
// Get global rank after the checkpoint? |
|
|
|
int32_t global_rank = ParallelContext::GetInstance()->global_rank(); |
|
|
|
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage(); |
|
|
|
@@ -1572,30 +1656,22 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
} |
|
|
|
bool load_strategy_from_ckpt = |
|
|
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); |
|
|
|
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { |
|
|
|
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != |
|
|
|
last_forward_node_ids.end(); |
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch(); |
|
|
|
if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) { |
|
|
|
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() |
|
|
|
<< " is empty, using batch parallel"; |
|
|
|
std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies(); |
|
|
|
if (strategy_v_ptr == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed"; |
|
|
|
} |
|
|
|
std::vector<ValuePtr> elements; |
|
|
|
for (size_t i = 0; i < strategy_v_ptr->size(); i++) { |
|
|
|
elements.push_back(MakeValue((*strategy_v_ptr)[i])); |
|
|
|
} |
|
|
|
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements); |
|
|
|
// display the strategy generated by batch parallel |
|
|
|
attrs[GEN_STRATEGY] = strategy; |
|
|
|
(void)prim->SetAttrs(attrs); |
|
|
|
MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is " |
|
|
|
<< attrs[GEN_STRATEGY]->ToString(); |
|
|
|
strategyPtr = NewStrategy(0, *strategy_v_ptr); |
|
|
|
strategyPtr = GenerateBatchParallelStrategy(operator_, prim); |
|
|
|
} else if (load_strategy_from_ckpt) { |
|
|
|
strategyPtr = stra_map[strategy_key_name]; |
|
|
|
} else { |
|
|
|
strategyPtr = ExtractStrategy(attrs); |
|
|
|
} |
|
|
|
if (strategyPtr != nullptr) { |
|
|
|
if (is_last_nodes && full_batch) { |
|
|
|
SetLastNodeStrategy(strategyPtr); |
|
|
|
} |
|
|
|
(*operator_).set_stage_id(strategyPtr->GetInputStage()); |
|
|
|
MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id(); |
|
|
|
if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) { |
|
|
|
@@ -2854,7 +2930,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) |
|
|
|
} |
|
|
|
|
|
|
|
// extract shape and strategy, set operator_info |
|
|
|
ExtractInformation(all_nodes); |
|
|
|
ExtractInformation(all_nodes, root->has_flag(TRAINING)); |
|
|
|
ReshapeInit(all_nodes); |
|
|
|
} |
|
|
|
|
|
|
|
|