|
|
|
@@ -28,6 +28,8 @@ |
|
|
|
#include "frontend/parallel/context.h" |
|
|
|
#include "frontend/parallel/step_parallel.h" |
|
|
|
#include "frontend/parallel/node_check.h" |
|
|
|
#include "ir/anf.h" |
|
|
|
#include "base/core_ops.h" |
|
|
|
#include "utils/comm_manager.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
|
|
|
|
@@ -136,6 +138,11 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) { |
|
|
|
|
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
// handle send/recv a parameter |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
MS_LOG(INFO) << "parameter: " << node->ToString() << " need to be send/recv."; |
|
|
|
return std::make_pair(nullptr, nullptr); |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
OperatorInfoPtr op_info = nullptr; |
|
|
|
@@ -170,6 +177,23 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const A |
|
|
|
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info)); |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto node_users = manager_->node_users()[node]; |
|
|
|
for (auto &user_pair : node_users) { |
|
|
|
auto user_node = user_pair.first->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_node); |
|
|
|
if (!IsPipelineCareNode(user_node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto op_info = CreateOpInfo(user_node); |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
auto tensor_info = op_info->inputs_tensor_info()[IntToSize(user_pair.second) - 1]; |
|
|
|
return std::make_pair(nullptr, std::make_shared<TensorInfo>(tensor_info)); |
|
|
|
} |
|
|
|
return std::make_pair(nullptr, nullptr); |
|
|
|
} |
|
|
|
|
|
|
|
void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) { |
|
|
|
auto need_coloring = true; |
|
|
|
while (need_coloring) { |
|
|
|
@@ -240,6 +264,7 @@ void PipelineTransformer::HandleSharedParameter() { |
|
|
|
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, ""); |
|
|
|
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, make_tuple}; |
|
|
|
auto depend = graph->NewCNode(depend_input); |
|
|
|
depend->set_abstract(parameter->abstract()); |
|
|
|
manager_->SetEdge(node, user.second, depend); |
|
|
|
break; |
|
|
|
} else { |
|
|
|
@@ -301,7 +326,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod |
|
|
|
auto send_op = CreatOpInstance(attrs, SEND, "send"); |
|
|
|
auto send_node = NewValueNode(send_op); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(send_node); |
|
|
|
auto op_info_pair = GetOpInfo(parameter); |
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair; |
|
|
|
if (parameter->isa<Parameter>()) { |
|
|
|
op_info_pair = GetParameterPair(parameter); |
|
|
|
} else { |
|
|
|
op_info_pair = GetOpInfo(parameter); |
|
|
|
} |
|
|
|
auto tensor_info = op_info_pair.second; |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_info); |
|
|
|
auto slice_shape = tensor_info->slice_shape(); |
|
|
|
@@ -314,6 +344,8 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod |
|
|
|
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "depend"); |
|
|
|
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send}; |
|
|
|
auto depend = graph->NewCNode(depend_input); |
|
|
|
auto abstract = parameter->abstract(); |
|
|
|
depend->set_abstract(abstract); |
|
|
|
SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend}; |
|
|
|
return send_out; |
|
|
|
} |
|
|
|
@@ -324,7 +356,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode |
|
|
|
recv_tag += 1; |
|
|
|
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_; |
|
|
|
Attr attr_rank = std::make_pair("src_rank", MakeValue(src_rank)); |
|
|
|
auto op_info_pair = GetOpInfo(node); |
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair; |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
op_info_pair = GetParameterPair(node); |
|
|
|
} else { |
|
|
|
op_info_pair = GetOpInfo(node); |
|
|
|
} |
|
|
|
auto tensor_info = op_info_pair.second; |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_info); |
|
|
|
auto slice_shape = tensor_info->slice_shape(); |
|
|
|
@@ -333,12 +370,19 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode |
|
|
|
Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second); |
|
|
|
OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype}; |
|
|
|
auto recv_op = CreatOpInstance(attrs, RECEIVE, "recv"); |
|
|
|
std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_}; |
|
|
|
std::vector<AnfNodePtr> recv_input; |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
recv_input = {NewValueNode(recv_op), node}; |
|
|
|
} else { |
|
|
|
recv_input = {NewValueNode(recv_op), virtual_param_}; |
|
|
|
} |
|
|
|
auto recv = graph->NewCNode(recv_input); |
|
|
|
auto node_abstract = node->abstract(); |
|
|
|
recv->set_abstract(node_abstract); |
|
|
|
recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout())); |
|
|
|
recv->set_user_data<OperatorInfo>(op_info_pair.first); |
|
|
|
if (op_info_pair.first != nullptr) { |
|
|
|
recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout())); |
|
|
|
recv->set_user_data<OperatorInfo>(op_info_pair.first); |
|
|
|
} |
|
|
|
manager_->SetEdge(use_node, index, recv); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -448,13 +492,6 @@ void PipelineTransformer::ElimGraphStage() { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { |
|
|
|
auto anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
auto prim = anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
return (prim->name() == name); |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() { |
|
|
|
std::pair<CNodePtr, FuncGraphPtr> sens_graph_pair; |
|
|
|
CNodePtr sens_cnode; |
|
|
|
@@ -471,7 +508,7 @@ std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() { |
|
|
|
} |
|
|
|
|
|
|
|
auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>(); |
|
|
|
if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) { |
|
|
|
if (!IsPrimitiveCNode(expect_tuple_getitem_cnode, prim::kPrimTupleGetItem)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto expect_anonymous = expect_tuple_getitem_cnode->input(1); |
|
|
|
@@ -484,7 +521,7 @@ std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto expect_j_cnode = expect_j->cast<CNodePtr>(); |
|
|
|
if (!IsSomePrimitive(expect_j_cnode, J)) { |
|
|
|
if (!IsPrimitiveCNode(expect_j_cnode, prim::kPrimJ)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1)); |
|
|
|
|