|
|
|
@@ -41,15 +41,18 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
|
|
|
|
// NodeUsersMap: if input i of node B uses node A, the map contains one item with key A and value (B, i). |
|
|
|
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) { |
|
|
|
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users, const CNodePtr &known_user, |
|
|
|
size_t known_index) { |
|
|
|
if (node_users.size() == 1) { |
|
|
|
MS_LOG(INFO) << "This node only used once, no need to insert tensormove node."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
for (const auto &node_pair : node_users) { |
|
|
|
auto node = node_pair.first; |
|
|
|
if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) { |
|
|
|
MS_LOG(INFO) << "This node only used other real kernel: " << node->fullname_with_scope(); |
|
|
|
auto &node = node_pair.first; |
|
|
|
size_t idx = IntToSize(node_pair.second); |
|
|
|
if (AnfAlgo::IsRealKernel(node) && !(known_user == node && known_index == idx)) { |
|
|
|
MS_LOG(INFO) << "User " << node->DebugString() << " idx " << idx << " is real kernel and diff with known " |
|
|
|
<< known_user->DebugString() << " idx " << known_index; |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -58,11 +61,13 @@ bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) { |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph, const AnfNodePtr &input, |
|
|
|
const CNodePtr &cur_node) const { |
|
|
|
bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph, const CNodePtr &cur_node, |
|
|
|
size_t input_idx) const { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
|
MS_EXCEPTION_IF_NULL(cur_node); |
|
|
|
auto input = cur_node->input(input_idx); |
|
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
|
|
|
|
|
if (IsPrimitiveCNode(cur_node, prim::kPrimReceive)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
@@ -93,7 +98,7 @@ bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph, |
|
|
|
MS_LOG(EXCEPTION) << "node has no output in manager" |
|
|
|
<< " trace: " << trace::DumpSourceLines(input); |
|
|
|
} |
|
|
|
if (IsNodeOutPutUsedByOtherRealKernel(iter->second)) { |
|
|
|
if (IsNodeOutPutUsedByOtherRealKernel(iter->second, cur_node, input_idx)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -107,7 +112,7 @@ void InsertTensorMoveForHcclOp::InsertTensorMove(const FuncGraphPtr &graph, cons |
|
|
|
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)}; |
|
|
|
for (size_t i = 1; i < hccl_node->size(); ++i) { |
|
|
|
auto input = hccl_node->input(i); |
|
|
|
if (NeedInsertTensorMove(graph, input, hccl_node)) { |
|
|
|
if (NeedInsertTensorMove(graph, hccl_node, i)) { |
|
|
|
auto tensor_move = CreateTensorMoveOp(graph, input); |
|
|
|
if (tensor_move == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Create tensor_move op failed."; |
|
|
|
|