|
|
|
@@ -98,16 +98,23 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con |
|
|
|
void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(hccl_node); |
|
|
|
if (hccl_node->size() != 2) { |
|
|
|
MS_LOG(INFO) << "node[" + AnfAlgo::GetCNodeName(hccl_node) + "]'s inputs size not equal 2"; |
|
|
|
return; |
|
|
|
bool has_insert_memcpy = false; |
|
|
|
AnfNodePtr memcpy_async = nullptr; |
|
|
|
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)}; |
|
|
|
for (size_t i = 1; i < hccl_node->size(); ++i) { |
|
|
|
auto input = hccl_node->input(i); |
|
|
|
if (NeedInsertMemcpy(graph, input)) { |
|
|
|
memcpy_async = CreateMemcpyAsyncOp(graph, input); |
|
|
|
has_insert_memcpy = true; |
|
|
|
new_inputs.push_back(memcpy_async); |
|
|
|
} else { |
|
|
|
new_inputs.push_back(input); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
auto input = hccl_node->input(1); |
|
|
|
if (NeedInsertMemcpy(graph, input)) { |
|
|
|
auto memcpy_async = CreateMemcpyAsyncOp(graph, input); |
|
|
|
if (has_insert_memcpy) { |
|
|
|
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node); |
|
|
|
new_hccl_node->set_inputs({hccl_node->input(0), memcpy_async}); |
|
|
|
new_hccl_node->set_inputs(new_inputs); |
|
|
|
auto manager = graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; |
|
|
|
@@ -115,7 +122,9 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co |
|
|
|
MS_LOG(DEBUG) << "end replace"; |
|
|
|
|
|
|
|
// transfer hccl op's control to the memcpy_async |
|
|
|
TransferControl(new_hccl_node, memcpy_async, graph); |
|
|
|
if (hccl_node->size() == 2) { |
|
|
|
TransferControl(new_hccl_node, memcpy_async, graph); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|