|
|
|
@@ -336,58 +336,63 @@ std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr |
|
|
|
return std::make_pair(op_info, tensor_info_index); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodeIndexSet PipelineTransformer::GetActualOpUsers(const AnfNodePtr &node, NodeUsersMap *node_users_map) { |
|
|
|
auto temp_users = (*node_users_map)[node]; |
|
|
|
auto temp_node = temp_users.front().first; |
|
|
|
if (IsPrimitiveCNode(temp_node, prim::kPrimLoad) || IsPrimitiveCNode(temp_node, prim::kPrimCast)) { |
|
|
|
return GetActualOpUsers(temp_node, node_users_map); |
|
|
|
} |
|
|
|
return temp_users; |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto node_users_map = manager_->node_users(); |
|
|
|
auto node_users = node_users_map[node]; |
|
|
|
for (auto &node_user : node_users) { |
|
|
|
auto load = node_user.first->cast<CNodePtr>(); |
|
|
|
if (IsPrimitiveCNode(load, prim::kPrimLoad)) { |
|
|
|
node_users = node_users_map[load]; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
for (auto &user_pair : node_users) { |
|
|
|
auto user_node = user_pair.first->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_node); |
|
|
|
auto user_node_graph = user_node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_node_graph); |
|
|
|
if (user_node_graph->stage() == -1) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto care_node = user_node; |
|
|
|
auto index = user_pair.second; |
|
|
|
if (IsValueNode<FuncGraph>(user_node->input(0))) { |
|
|
|
auto graph = GetValueNode<FuncGraphPtr>(user_node->input(0)); |
|
|
|
auto temp_params = graph->parameters(); |
|
|
|
if (temp_params.size() < IntToSize(user_pair.second)) { |
|
|
|
MS_LOG(EXCEPTION) << "parameter:" << node->DebugString() << " out of graph: " << graph->ToString() |
|
|
|
<< "'s range."; |
|
|
|
auto load_users = GetActualOpUsers(node_user.first, &node_users_map); |
|
|
|
for (auto &user_pair : load_users) { |
|
|
|
auto user_node = user_pair.first->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_node); |
|
|
|
auto user_node_graph = user_node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_node_graph); |
|
|
|
if (user_node_graph->stage() == -1) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto temp_param = temp_params[user_pair.second - 1]; |
|
|
|
auto temp_users = node_users_map[temp_param]; |
|
|
|
for (auto &temp_user : temp_users) { |
|
|
|
auto load_temp = temp_user.first->cast<CNodePtr>(); |
|
|
|
if (IsPrimitiveCNode(load_temp, prim::kPrimLoad)) { |
|
|
|
temp_users = node_users_map[load_temp]; |
|
|
|
auto care_node = user_node; |
|
|
|
auto index = user_pair.second; |
|
|
|
if (IsValueNode<FuncGraph>(user_node->input(0))) { |
|
|
|
auto graph = GetValueNode<FuncGraphPtr>(user_node->input(0)); |
|
|
|
auto temp_params = graph->parameters(); |
|
|
|
if (temp_params.size() < IntToSize(user_pair.second)) { |
|
|
|
MS_LOG(EXCEPTION) << "parameter:" << node->DebugString() << " out of graph: " << graph->ToString() |
|
|
|
<< "'s range."; |
|
|
|
} |
|
|
|
auto temp_param = temp_params[user_pair.second - 1]; |
|
|
|
auto temp_users = node_users_map[temp_param]; |
|
|
|
for (auto &temp_user : temp_users) { |
|
|
|
auto load_temp = temp_user.first->cast<CNodePtr>(); |
|
|
|
if (IsPrimitiveCNode(load_temp, prim::kPrimLoad)) { |
|
|
|
temp_users = node_users_map[load_temp]; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
for (auto &temp_pair : temp_users) { |
|
|
|
auto temp_cnode = temp_pair.first->cast<CNodePtr>(); |
|
|
|
if (!IsPipelineCareNode(temp_cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
care_node = temp_cnode; |
|
|
|
index = temp_pair.second; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
for (auto &temp_pair : temp_users) { |
|
|
|
auto temp_cnode = temp_pair.first->cast<CNodePtr>(); |
|
|
|
if (!IsPipelineCareNode(temp_cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
care_node = temp_cnode; |
|
|
|
index = temp_pair.second; |
|
|
|
break; |
|
|
|
if (!IsPipelineCareNode(care_node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto op_info = CreateOpInfo(care_node); |
|
|
|
return std::make_pair(op_info, index - 1); |
|
|
|
} |
|
|
|
if (!IsPipelineCareNode(care_node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto op_info = CreateOpInfo(care_node); |
|
|
|
return std::make_pair(op_info, index - 1); |
|
|
|
} |
|
|
|
return std::make_pair(nullptr, 0); |
|
|
|
} |
|
|
|
@@ -432,7 +437,8 @@ std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() { |
|
|
|
if (receive) { |
|
|
|
manager_->SetEdge(node, user.second, receive); |
|
|
|
} else { |
|
|
|
auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro); |
|
|
|
auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro, |
|
|
|
parameter); |
|
|
|
recvs.push_back(recv); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -594,7 +600,8 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod |
|
|
|
|
|
|
|
AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, |
|
|
|
const AnfNodePtr &use_node, int index, int64_t user_node_stage, |
|
|
|
int64_t node_stage, const ValuePtr &value) { |
|
|
|
int64_t node_stage, const ValuePtr &value, |
|
|
|
const AnfNodePtr &graph_param) { |
|
|
|
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_; |
|
|
|
int64_t recv_tag; |
|
|
|
if (recv_tag_map.find(src_rank) != recv_tag_map.end()) { |
|
|
|
@@ -610,18 +617,13 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A |
|
|
|
bool is_param = true; |
|
|
|
TensorInfo tensor_info; |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
op_info_pair = GetParameterPair(node); |
|
|
|
op_info_pair = GetParameterPair(graph_param); |
|
|
|
tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second)); |
|
|
|
} else { |
|
|
|
auto care_node = FindPipelineCareNode(node); |
|
|
|
if (care_node->isa<Parameter>()) { |
|
|
|
op_info_pair = GetParameterPair(care_node); |
|
|
|
tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second)); |
|
|
|
} else { |
|
|
|
op_info_pair = GetOpInfo(care_node); |
|
|
|
tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second)); |
|
|
|
is_param = false; |
|
|
|
} |
|
|
|
op_info_pair = GetOpInfo(care_node); |
|
|
|
tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second)); |
|
|
|
is_param = false; |
|
|
|
} |
|
|
|
auto tensor_layout = tensor_info.tensor_layout(); |
|
|
|
Shape slice_shape = tensor_info.slice_shape(); |
|
|
|
@@ -694,6 +696,119 @@ AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, con |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr PipelineTransformer::ActualOp(const AnfNodePtr &node) { |
|
|
|
// skip some virtual op like:Depend, Load, Cast. |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimCast) || |
|
|
|
IsPrimitiveCNode(node, prim::kPrimLoad)) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
return ActualOp(cnode->input(1)); |
|
|
|
} |
|
|
|
return node; |
|
|
|
} |
|
|
|
|
|
|
|
bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) { |
|
|
|
// ParameterGraph: graph which return a parameter |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
// parameter_graph->return->load->graph |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimLoad)) { |
|
|
|
auto graph_cnode = cnode->input(1)->cast<CNodePtr>(); |
|
|
|
if (!graph_cnode) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (!IsValueNode<FuncGraph>(graph_cnode->input(0))) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
// Now load's input must be a parameter |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// parameter_graph->return->graph |
|
|
|
if (!IsValueNode<FuncGraph>(cnode->input(0))) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0)); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto graph_out = graph->output(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_out); |
|
|
|
auto actual_op = ActualOp(graph_out); |
|
|
|
MS_EXCEPTION_IF_NULL(actual_op); |
|
|
|
if (actual_op->isa<Parameter>()) { |
|
|
|
auto parameter_list = graph->parameters(); |
|
|
|
// parameter_graph->parameter->return->graph |
|
|
|
auto parameter_iter = std::find(parameter_list.begin(), parameter_list.end(), actual_op); |
|
|
|
if (parameter_iter == parameter_list.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
// parameter->graph->return->graph |
|
|
|
auto pos = std::distance(parameter_list.begin(), parameter_iter); |
|
|
|
if (!cnode->input(pos + 1)->isa<Parameter>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, |
|
|
|
int64_t user_stage, const ValuePtr µ, size_t pos, |
|
|
|
const std::vector<AnfNodePtr> ops) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
AnfNodePtr argument; |
|
|
|
AnfNodePtr parameter; |
|
|
|
FuncGraphPtr graph; |
|
|
|
// parameter_graph->load->graph |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimLoad)) { |
|
|
|
auto graph_cnode = cnode->input(1)->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_cnode); |
|
|
|
graph = GetValueNode<FuncGraphPtr>(graph_cnode->input(0)); |
|
|
|
} else { |
|
|
|
graph = GetValueNode<FuncGraphPtr>(cnode->input(0)); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
|
|
auto graph_out = ActualOp(graph->output()); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_out); |
|
|
|
auto parameter_list = graph->parameters(); |
|
|
|
auto param_iter = std::find(parameter_list.begin(), parameter_list.end(), graph_out); |
|
|
|
auto use_cnode = use_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(use_cnode); |
|
|
|
if (!IsValueNode<FuncGraph>(use_cnode->input(0))) { |
|
|
|
MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString(); |
|
|
|
} |
|
|
|
auto use_graph = GetValueNode<FuncGraphPtr>(use_cnode->input(0)); |
|
|
|
auto use_parameter_list = use_graph->parameters(); |
|
|
|
parameter = use_parameter_list.at(pos - 1); |
|
|
|
// argument->load->graph |
|
|
|
if (param_iter == parameter_list.end()) { |
|
|
|
argument = graph_out; |
|
|
|
} else { |
|
|
|
auto param_pos = std::distance(parameter_list.begin(), param_iter); |
|
|
|
argument = cnode->input(param_pos + 1); |
|
|
|
} |
|
|
|
|
|
|
|
// insert receive |
|
|
|
if (stage_ == user_stage) { |
|
|
|
auto recv = Reuse(argument, stage, ops, SRC_RANK); |
|
|
|
if (recv) { |
|
|
|
manager_->SetEdge(use_node, pos, recv); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return InsertReceive(main_graph_, argument, use_node, pos, user_stage, stage, micro, parameter); |
|
|
|
} |
|
|
|
// insert send |
|
|
|
if (Reuse(argument, user_stage, ops, DEST_RANK)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto send_out = InsertSend(main_graph_, argument, user_stage, stage_, micro); |
|
|
|
send_out.depend->set_user_data<Type>(DTYPE, send_out.type); |
|
|
|
send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape); |
|
|
|
return send_out.depend; |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { |
|
|
|
OperatorAttrs depend_attrs; |
|
|
|
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND); |
|
|
|
@@ -708,7 +823,7 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer: |
|
|
|
MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num; |
|
|
|
} |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
if (!node->isa<CNode>() || node->stage() == -1) { |
|
|
|
if (!node->isa<CNode>() || node->stage() == -1 || IsPrimitiveCNode(node, prim::kPrimUpdateState)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto node_users = manager_->node_users()[node]; |
|
|
|
@@ -727,6 +842,15 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer: |
|
|
|
} |
|
|
|
if (node_stage < user_node_stage) { |
|
|
|
if (node_stage == stage_) { |
|
|
|
if (IsParameterGraph(node)) { |
|
|
|
auto send_depend = |
|
|
|
HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, send_ops); |
|
|
|
if (!send_depend) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
send_ops.insert(send_ops.begin(), send_depend); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -737,8 +861,18 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer: |
|
|
|
send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape); |
|
|
|
} else { |
|
|
|
if (!receive) { |
|
|
|
receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro); |
|
|
|
receive_ops.push_back(receive); |
|
|
|
if (IsParameterGraph(node)) { |
|
|
|
receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, |
|
|
|
receive_ops); |
|
|
|
if (!receive) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
receive_ops.push_back(receive); |
|
|
|
} else { |
|
|
|
receive = |
|
|
|
InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node); |
|
|
|
receive_ops.push_back(receive); |
|
|
|
} |
|
|
|
} else { |
|
|
|
manager_->SetEdge(user_node, user_pair.second, receive); |
|
|
|
} |
|
|
|
|