|
|
|
@@ -286,7 +286,7 @@ void PipelineTransformer::HandleSharedParameter() { |
|
|
|
manager_->SetEdge(node, user.second, depend); |
|
|
|
break; |
|
|
|
} else { |
|
|
|
InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin()); |
|
|
|
(void)InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin()); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -403,8 +403,9 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod |
|
|
|
return send_out; |
|
|
|
} |
|
|
|
|
|
|
|
void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, |
|
|
|
int index, int64_t user_node_stage, int64_t node_stage) { |
|
|
|
AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, |
|
|
|
const AnfNodePtr &use_node, int index, int64_t user_node_stage, |
|
|
|
int64_t node_stage) { |
|
|
|
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_; |
|
|
|
int64_t recv_tag; |
|
|
|
if (recv_tag_map.find(src_rank) != recv_tag_map.end()) { |
|
|
|
@@ -464,6 +465,28 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode |
|
|
|
recv->set_user_data<OperatorInfo>(op_info_pair.first); |
|
|
|
} |
|
|
|
manager_->SetEdge(use_node, index, recv); |
|
|
|
return recv; |
|
|
|
} |
|
|
|
|
|
|
|
bool PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage, |
|
|
|
const std::vector<AnfNodePtr> &out_input) { |
|
|
|
auto node_users = manager_->node_users()[node]; |
|
|
|
auto dest_rank = global_rank_ + (next_node_stage - node_stage) * per_stage_rank_num_; |
|
|
|
for (auto &depend : out_input) { |
|
|
|
if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cnode = depend->cast<CNodePtr>(); |
|
|
|
if (cnode->input(1) == node) { |
|
|
|
auto send_cnode = cnode->input(2)->cast<CNodePtr>(); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(send_cnode->input(0)); |
|
|
|
auto dest_rank_send = GetValue<int64_t>(prim->GetAttr("dest_rank")); |
|
|
|
if (dest_rank_send == dest_rank) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<bool, int64_t> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) { |
|
|
|
@@ -496,6 +519,7 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { |
|
|
|
auto shared_min_tag_pair = IsSharedNode(node, node_users); |
|
|
|
auto is_shared = shared_min_tag_pair.first; |
|
|
|
auto min_tag = shared_min_tag_pair.second; |
|
|
|
AnfNodePtr receive = nullptr; |
|
|
|
for (auto &user_pair : node_users) { |
|
|
|
auto user_node = user_pair.first; |
|
|
|
auto node_stage = node->stage(); |
|
|
|
@@ -508,18 +532,19 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (node_stage == stage_) { |
|
|
|
if (Reuse(node, user_node_stage, node_stage, out_input)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto send_out = InsertSend(graph, node, user_node_stage, node_stage); |
|
|
|
out_input.insert(out_input.begin() + 1, send_out.depend); |
|
|
|
type_ptr_ = send_out.type; |
|
|
|
shape_ = send_out.shape; |
|
|
|
} else { |
|
|
|
InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage); |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (node_stage == user_node_stage) { |
|
|
|
if (is_shared && (min_tag != node_stage)) { |
|
|
|
InsertReceive(graph, node, user_node, user_pair.second, stage_, min_tag); |
|
|
|
if (!receive) { |
|
|
|
receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage); |
|
|
|
} else { |
|
|
|
manager_->SetEdge(user_node, user_pair.second, receive); |
|
|
|
} |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
|