| @@ -23,6 +23,12 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| bool InputIsParameterOrValueNode(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto kernel_with_index = AnfAlgo::VisitKernel(node, 0); | |||||
| return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>(); | |||||
| } | |||||
| const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) { | const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| @@ -39,7 +45,8 @@ const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, | |||||
| if (manager->node_users().find(input) == manager->node_users().end()) { | if (manager->node_users().find(input) == manager->node_users().end()) { | ||||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | MS_LOG(EXCEPTION) << "node has no output in manager"; | ||||
| } | } | ||||
| if (manager->node_users()[input].size() > 1) { | |||||
| // when input is used by others or is a parameter or is a value node, insert a memcpy_async | |||||
| if (manager->node_users()[input].size() > 1 || InputIsParameterOrValueNode(input)) { | |||||
| replace = true; | replace = true; | ||||
| new_inputs.push_back(CreateMemcpyAsyncOp(graph, input)); | new_inputs.push_back(CreateMemcpyAsyncOp(graph, input)); | ||||
| } else { | } else { | ||||