|
|
|
@@ -40,6 +40,38 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) { |
|
|
|
return real_node->isa<ValueNode>(); |
|
|
|
} |
|
|
|
|
|
|
|
void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node, |
|
|
|
const std::vector<AnfNodePtr> &memcpy_async_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(control_depend); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_node); |
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; |
|
|
|
make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end()); |
|
|
|
make_tuple_inputs.emplace_back(hccl_node); |
|
|
|
auto make_tuple = graph->NewCNode(make_tuple_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
control_depend->set_input(IntToSize(index), make_tuple); |
|
|
|
} |
|
|
|
|
|
|
|
void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node, |
|
|
|
const std::vector<AnfNodePtr> &memcpy_async_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
auto manager = graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto &node_users = manager->node_users(); |
|
|
|
auto iter = node_users.find(tuple_getitem); |
|
|
|
if (iter == node_users.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "node has no output in manager"; |
|
|
|
} |
|
|
|
for (const auto &node_index : iter->second) { |
|
|
|
AnfNodePtr output = node_index.first; |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { |
|
|
|
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list, |
|
|
|
const FuncGraphPtr &graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_node); |
|
|
|
@@ -53,25 +85,13 @@ void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &m |
|
|
|
} |
|
|
|
// find hccl_node's output which is a control depend |
|
|
|
for (const auto &node_index : iter->second) { |
|
|
|
if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
CNodePtr control_depend = node_index.first->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(control_depend); |
|
|
|
std::vector<AnfNodePtr> new_inputs; |
|
|
|
for (size_t i = 0; i < control_depend->size(); ++i) { |
|
|
|
if (i == IntToSize(node_index.second)) { |
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; |
|
|
|
make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end()); |
|
|
|
make_tuple_inputs.emplace_back(hccl_node); |
|
|
|
auto make_tuple = graph->NewCNode(make_tuple_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
new_inputs.push_back(make_tuple); |
|
|
|
} else { |
|
|
|
new_inputs.push_back(control_depend->input(i)); |
|
|
|
} |
|
|
|
AnfNodePtr output = node_index.first; |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { |
|
|
|
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list); |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) { |
|
|
|
DealControlForGetitem(output->cast<CNodePtr>(), graph, hccl_node, memcpy_async_list); |
|
|
|
} |
|
|
|
control_depend->set_inputs(new_inputs); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
@@ -148,11 +168,10 @@ const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_gr |
|
|
|
if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (!AnfAlgo::IsCommunicationOp(node)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
InsertMemcpyAsync(func_graph, cnode); |
|
|
|
InsertMemcpyAsync(func_graph, node->cast<CNodePtr>()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
|