|
|
|
@@ -956,64 +956,71 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { |
|
|
|
if (IsValueNode<RefKey>(node)) { |
|
|
|
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph); |
|
|
|
if (param_v.size() != 1) { |
|
|
|
MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is " |
|
|
|
<< param_v.size(); |
|
|
|
} |
|
|
|
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>(); |
|
|
|
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) { |
|
|
|
return std::make_pair(nullptr, true); |
|
|
|
} |
|
|
|
return std::make_pair(node, true); |
|
|
|
} |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} |
|
|
|
|
|
|
|
// Only used for InsertMirrorOps |
|
|
|
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { |
|
|
|
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) { |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} else if (node->isa<Parameter>()) { |
|
|
|
} |
|
|
|
|
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
auto param_ptr = node->user_data<parallel::TensorLayout>(); |
|
|
|
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) { |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} else { |
|
|
|
return std::make_pair(node, false); |
|
|
|
} |
|
|
|
} else if (node->isa<ValueNode>()) { |
|
|
|
if (IsValueNode<RefKey>(node)) { |
|
|
|
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph); |
|
|
|
if (param_v.size() != 1) { |
|
|
|
MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is " |
|
|
|
<< param_v.size(); |
|
|
|
} |
|
|
|
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>(); |
|
|
|
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) { |
|
|
|
return std::make_pair(nullptr, true); |
|
|
|
} else { |
|
|
|
return std::make_pair(node, true); |
|
|
|
} |
|
|
|
return std::make_pair(node, false); |
|
|
|
} |
|
|
|
|
|
|
|
if (node->isa<ValueNode>()) { |
|
|
|
return FindParameterByValueNode(node, func_graph); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) { |
|
|
|
if (!FindParameter(cnode->input(index), func_graph).first) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
return FindParameter(cnode->input(index), func_graph); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) { |
|
|
|
return std::make_pair(node, false); |
|
|
|
} |
|
|
|
|
|
|
|
if (IsParallelCareNode(cnode)) { |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} else { |
|
|
|
CNodePtr cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) { |
|
|
|
if (!FindParameter(cnode->input(index), func_graph).first) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
return FindParameter(cnode->input(index), func_graph); |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) { |
|
|
|
return std::make_pair(node, false); |
|
|
|
} |
|
|
|
if (IsParallelCareNode(cnode)) { |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} else { |
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(prim_anf_node); |
|
|
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) { |
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
if ((prim->name() == DEPEND || prim->name() == LOAD) && index != 1) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!FindParameter(cnode->input(index), func_graph).first) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
return FindParameter(cnode->input(index), func_graph); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(prim_anf_node); |
|
|
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) { |
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
if ((prim->name() == DEPEND || prim->name() == LOAD) && index != 1) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!FindParameter(cnode->input(index), func_graph).first) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
return FindParameter(cnode->input(index), func_graph); |
|
|
|
} |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} |
|
|
|
@@ -1101,6 +1108,25 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par |
|
|
|
MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; |
|
|
|
} |
|
|
|
|
|
|
|
static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) { |
|
|
|
if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) { |
|
|
|
MS_LOG(INFO) << "Input is ValueList, skip it."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if ((node->inputs().size() == 2) && |
|
|
|
(AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) { |
|
|
|
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (mirror_ops.size() != node_size - 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is " |
|
|
|
<< node_size - 1; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
size_t node_size = node->inputs().size(); |
|
|
|
@@ -1113,21 +1139,11 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons |
|
|
|
node_size--; |
|
|
|
} |
|
|
|
} |
|
|
|
if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) { |
|
|
|
MS_LOG(INFO) << "Input is ValueList, skip it."; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
if ((node->inputs().size() == 2) && |
|
|
|
(AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) { |
|
|
|
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node"; |
|
|
|
if (!CheckInsertMirrorOps(mirror_ops, node, node_size)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
if (mirror_ops.size() != node_size - 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is " |
|
|
|
<< node_size - 1; |
|
|
|
} |
|
|
|
for (size_t index = 1; index < node_size; ++index) { |
|
|
|
OperatorVector backward_op = mirror_ops[index - 1]; |
|
|
|
if (backward_op.empty()) { |
|
|
|
@@ -1181,15 +1197,15 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons |
|
|
|
// pipeline mirror would not be set, which should be supported later |
|
|
|
AddCommOpFusionType(comm_op, param_node_pair.first); |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (auto &op : backward_op) { |
|
|
|
AnfNodePtr pre_node = node->input(index); |
|
|
|
InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name); |
|
|
|
auto comm_op = node->input(index)->cast<CNodePtr>(); |
|
|
|
// add fusion flag |
|
|
|
// pipeline mirror would not be set, which should be supported later |
|
|
|
AddCommOpFusionType(comm_op, param_node_pair.first); |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (auto &op : backward_op) { |
|
|
|
AnfNodePtr pre_node = node->input(index); |
|
|
|
InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name); |
|
|
|
auto comm_op = node->input(index)->cast<CNodePtr>(); |
|
|
|
// add fusion flag |
|
|
|
// pipeline mirror would not be set, which should be supported later |
|
|
|
AddCommOpFusionType(comm_op, param_node_pair.first); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1849,13 +1865,29 @@ void SetLastNodeStrategy(const StrategyPtr strategyPtr) { |
|
|
|
strategyPtr->ResetInputs(strategys); |
|
|
|
} |
|
|
|
|
|
|
|
static bool CheckExtractInfomation(const CNodePtr &cnode) { |
|
|
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); |
|
|
|
if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (!IsParallelCareNode(cnode)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) { |
|
|
|
// load strategy map from checkpoint |
|
|
|
StrategyMap stra_map; |
|
|
|
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { |
|
|
|
if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; |
|
|
|
} |
|
|
|
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() && |
|
|
|
(StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) { |
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; |
|
|
|
} |
|
|
|
vector<std::string> last_forward_node_ids; |
|
|
|
if (!is_training) { |
|
|
|
@@ -1865,76 +1897,71 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini |
|
|
|
|
|
|
|
for (auto &node : all_nodes) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
if (!CheckExtractInfomation(cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
SetVirtualDatasetStrategy(cnode); |
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); |
|
|
|
if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST || prim->name() == RECEIVE) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto attrs = prim->attrs(); |
|
|
|
MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name(); |
|
|
|
if (IsParallelCareNode(cnode)) { |
|
|
|
std::vector<Shapes> shape_list = ExtractShape(cnode); |
|
|
|
if (shape_list.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; |
|
|
|
} |
|
|
|
OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); |
|
|
|
if (operator_ == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed"; |
|
|
|
} |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
std::vector<ValuePtr> input_value; |
|
|
|
for (size_t index = 1; index < inputs.size(); ++index) { |
|
|
|
if (inputs[index]->isa<ValueNode>()) { |
|
|
|
input_value.push_back(GetValueNode(inputs[index])); |
|
|
|
} else { |
|
|
|
input_value.emplace_back(nullptr); |
|
|
|
} |
|
|
|
} |
|
|
|
StrategyPtr strategyPtr = nullptr; |
|
|
|
(*operator_).set_input_value(input_value); |
|
|
|
(*operator_).set_outputs_dtype(cnode->Type()); |
|
|
|
(*operator_).set_cnode(cnode); |
|
|
|
if (prim->name() == RESHAPE) { |
|
|
|
cnode->set_user_data<OperatorInfo>(operator_); |
|
|
|
|
|
|
|
std::vector<Shapes> shape_list = ExtractShape(cnode); |
|
|
|
if (shape_list.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; |
|
|
|
} |
|
|
|
OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); |
|
|
|
MS_EXCEPTION_IF_NULL(operator_); |
|
|
|
|
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
std::vector<ValuePtr> input_value; |
|
|
|
for (size_t index = 1; index < inputs.size(); ++index) { |
|
|
|
if (inputs[index]->isa<ValueNode>()) { |
|
|
|
input_value.push_back(GetValueNode(inputs[index])); |
|
|
|
continue; |
|
|
|
} |
|
|
|
// load strategy checkpoint |
|
|
|
// key of strategy map |
|
|
|
std::string strategy_key_name = ""; |
|
|
|
auto param_names = NodeParameterName(cnode); |
|
|
|
if (!param_names.empty()) { |
|
|
|
strategy_key_name = prim->name() + "_" + param_names[0].first; |
|
|
|
} |
|
|
|
bool load_strategy_from_ckpt = |
|
|
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); |
|
|
|
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != |
|
|
|
last_forward_node_ids.end(); |
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch(); |
|
|
|
if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) { |
|
|
|
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() |
|
|
|
<< " is empty, using batch parallel"; |
|
|
|
strategyPtr = GenerateBatchParallelStrategy(operator_, prim); |
|
|
|
} else if (StrategyFound(attrs)) { |
|
|
|
strategyPtr = ExtractStrategy(attrs); |
|
|
|
} else { |
|
|
|
strategyPtr = stra_map[strategy_key_name]; |
|
|
|
} |
|
|
|
if (strategyPtr != nullptr) { |
|
|
|
if (is_last_nodes && full_batch) { |
|
|
|
SetLastNodeStrategy(strategyPtr); |
|
|
|
} |
|
|
|
if (operator_->Init(strategyPtr) == FAILED) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; |
|
|
|
} |
|
|
|
cnode->set_user_data<OperatorInfo>(operator_); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; |
|
|
|
} |
|
|
|
input_value.emplace_back(nullptr); |
|
|
|
} |
|
|
|
StrategyPtr strategyPtr = nullptr; |
|
|
|
(*operator_).set_input_value(input_value); |
|
|
|
(*operator_).set_outputs_dtype(cnode->Type()); |
|
|
|
(*operator_).set_cnode(cnode); |
|
|
|
if (prim->name() == RESHAPE) { |
|
|
|
cnode->set_user_data<OperatorInfo>(operator_); |
|
|
|
continue; |
|
|
|
} |
|
|
|
// load strategy checkpoint |
|
|
|
// key of strategy map |
|
|
|
std::string strategy_key_name = ""; |
|
|
|
auto param_names = NodeParameterName(cnode); |
|
|
|
if (!param_names.empty()) { |
|
|
|
strategy_key_name = prim->name() + "_" + param_names[0].first; |
|
|
|
} |
|
|
|
bool load_strategy_from_ckpt = |
|
|
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); |
|
|
|
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != |
|
|
|
last_forward_node_ids.end(); |
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch(); |
|
|
|
if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) { |
|
|
|
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() |
|
|
|
<< " is empty, using batch parallel"; |
|
|
|
strategyPtr = GenerateBatchParallelStrategy(operator_, prim); |
|
|
|
} else if (StrategyFound(attrs)) { |
|
|
|
strategyPtr = ExtractStrategy(attrs); |
|
|
|
} else { |
|
|
|
strategyPtr = stra_map[strategy_key_name]; |
|
|
|
} |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(strategyPtr); |
|
|
|
if (is_last_nodes && full_batch) { |
|
|
|
SetLastNodeStrategy(strategyPtr); |
|
|
|
} |
|
|
|
if (operator_->Init(strategyPtr) == FAILED) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; |
|
|
|
} |
|
|
|
cnode->set_user_data<OperatorInfo>(operator_); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1994,9 +2021,9 @@ std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, si |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(distribute_operator); |
|
|
|
if (distribute_operator->outputs_tensor_info().size() < output_index) { |
|
|
|
if (distribute_operator->outputs_tensor_info().size() <= output_index) { |
|
|
|
MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size() |
|
|
|
<< ", must be less than output_index " << output_index; |
|
|
|
<< ", must be greater than output_index " << output_index; |
|
|
|
} |
|
|
|
TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index]; |
|
|
|
TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); |
|
|
|
|