|
|
|
@@ -808,11 +808,17 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
   }
 }
 
+// Only used for InsertMirrorOps
 std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
     return std::make_pair(nullptr, false);
   } else if (node->isa<Parameter>()) {
-    return std::make_pair(node, false);
+    auto param_ptr = node->user_data<parallel::TensorLayout>();
+    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+      return std::make_pair(nullptr, false);
+    } else {
+      return std::make_pair(node, false);
+    }
   } else if (node->isa<ValueNode>()) {
     if (IsValueNode<RefKey>(node)) {
       std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
@@ -820,7 +826,12 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
         MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
                           << param_v.size();
       }
-      return std::make_pair(node, true);
+      auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
+      if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+        return std::make_pair(nullptr, true);
+      } else {
+        return std::make_pair(node, true);
+      }
     }
     return std::make_pair(nullptr, false);
   } else {
@@ -1002,7 +1013,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
   MirrorOps mirror_ops = distribute_operator->mirror_ops();
   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
   // insert mirror op
-  if (!mirror_ops.empty() && !distribute_operator->opt_shard_flag()) {
+  if (!mirror_ops.empty()) {
     MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
     InsertMirrorOps(mirror_ops, node);
   }
@@ -1374,39 +1385,51 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
   return std::make_pair(nullptr, 0);
 }
 
-void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
-                             const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index) {
-  MS_EXCEPTION_IF_NULL(distribute_operator);
-  MS_EXCEPTION_IF_NULL(cnode);
+void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, const AnfNodePtr &parameter) {
+  Operator op = CreateAllGatherOp(group);
+  MS_EXCEPTION_IF_NULL(res.first);
   MS_EXCEPTION_IF_NULL(parameter);
-  std::vector<Group> dev_group;
-  // create communication group for allgather operator
-  if (distribute_operator->CreateGroupByTensorMap(tensor_layout->origin_tensor_map().array(), &dev_group) ==
-        Status::SUCCESS &&
-      !dev_group.empty()) {
-    // set optimizer shard split flag to avoid inserting mirror_ops
-    distribute_operator->set_opt_shard_flag(true);
-    // insert allgather operator between shard parameter and cnode
-    Operator op = CreateAllGatherOp(dev_group[0].name());
-    auto graph = cnode->func_graph();
-    MS_EXCEPTION_IF_NULL(graph);
-    InsertNode(op, cnode, index, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
-    // set communication group in tensor layout for checkpoint saving
-    tensor_layout->set_opt_shard_group(dev_group[0].name());
-    // add fusion flag
-    auto allgather = cnode->input(index)->cast<CNodePtr>();
-    auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
-    auto attrs = prim->attrs();
-    // enable fusion flag later when it's supported in backend
-    attrs["fusion"] = MakeValue(0);
-    prim->SetAttrs(attrs);
-    MS_LOG(INFO) << "Parallel optimizer is applied on " << parameter->ToString();
-  } else {
-    MS_LOG(ERROR) << "Parallel optimizer applied on " << parameter->ToString() << "failed!";
+  auto cnode = res.first->cast<CNodePtr>();
+  auto graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(graph);
+  InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
+  // add fusion flag
+  auto allgather = cnode->input(res.second)->cast<CNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
+  auto attrs = prim->attrs();
+  // enable fusion flag later when it's supported in backend
+  attrs["fusion"] = MakeValue(0);
+  prim->SetAttrs(attrs);
 }
 
+void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
+                             const std::string &opt_shard_group) {
+  if (opt_shard_group.empty()) {
+    return;
+  }
+  FuncGraphManagerPtr manager = root->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto param_sub_set = manager->node_users()[parameter];
+  for (auto &param_pair : param_sub_set) {
+    auto cnode = param_pair.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (cnode->in_forward_flag()) {
+      OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
+      if (distribute_operator == nullptr) {
+        MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
+      } else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
+        MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
+                          << distribute_operator->inputs_tensor_info().size();
+      }
+      // insert allgather operator between shard parameter and cnode
+      InsertAllGatherOp(opt_shard_group, param_pair, parameter);
+      MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
+    }
+  }
+}
 
-void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
+// When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
+std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
   MS_EXCEPTION_IF_NULL(parameter);
   AbstractBasePtr abstract = parameter->abstract();
   MS_EXCEPTION_IF_NULL(abstract);
@@ -1417,26 +1440,40 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
   if (distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
   }
 
   if (IntToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
     MS_LOG(EXCEPTION) << "The index is out of range, index is " << res.second - 1 << ", vector size is "
                       << distribute_operator->inputs_tensor_info().size();
   }
   TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
-  Shape slice_shape = tensor_layout.slice_shape().array();
+  std::string opt_shard_group;
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
+  Shape slice_shape = tensor_layout.slice_shape().array();
   if (enable_parallel_optimizer) {
     if (!ParameterRequireGrad(parameter)) {
       // only trainable parameters need parallel optimizer
-      MS_LOG(INFO) << "Parallel optimizer is no need for " << parameter->ToString();
+      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
     } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
       // get a totally shard tensor slice shape if the weight is repeated on devices
       // and the shape of the first dimension could be divided
-      // apply parallel optimizer on parameters
-      ApplyParallelOptOnParam(&tensor_layout, distribute_operator, cnode, parameter, IntToSize(res.second));
+      // create communication group for allgather operator
       slice_shape = tensor_layout.opt_shard_slice_shape();
+      std::vector<Group> dev_group;
+      if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
+            Status::SUCCESS &&
+          !dev_group.empty()) {
+        opt_shard_group = dev_group[0].name();
+        // set communication group in tensor layout for checkpoint saving
+        tensor_layout.set_opt_shard_group(opt_shard_group);
+        MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
+                     << " success.";
+      } else {
+        MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
+      }
+    } else {
+      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
     }
   }
   MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
@@ -1451,6 +1488,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
   parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
+  return opt_shard_group;
 }
 
 void CoverSliceShape(const FuncGraphPtr &root) {
@@ -1460,14 +1498,18 @@ void CoverSliceShape(const FuncGraphPtr &root) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
     auto iter = g_RefMap.find(parameter);
     if (iter != g_RefMap.end()) {
-      SetParallelShape(parameter, g_RefMap[parameter]);
+      std::string group = SetParallelShape(parameter, g_RefMap[parameter]);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
+      ApplyParallelOptOnParam(root, parameter, group);
       continue;
     }
     std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
-      SetParallelShape(parameter, res);
+      std::string group = SetParallelShape(parameter, res);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
+      ApplyParallelOptOnParam(root, parameter, group);
       MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }