|
|
|
@@ -23,6 +23,12 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
bool InputIsParameterOrValueNode(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_with_index = AnfAlgo::VisitKernel(node, 0); |
|
|
|
return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>(); |
|
|
|
} |
|
|
|
|
|
|
|
const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -39,7 +45,8 @@ const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, |
|
|
|
if (manager->node_users().find(input) == manager->node_users().end()) { |
|
|
|
MS_LOG(EXCEPTION) << "node has no output in manager"; |
|
|
|
} |
|
|
|
if (manager->node_users()[input].size() > 1) { |
|
|
|
// when input is used by others or is a parameter or is a value node, insert a memcpy_async |
|
|
|
if (manager->node_users()[input].size() > 1 || InputIsParameterOrValueNode(input)) { |
|
|
|
replace = true; |
|
|
|
new_inputs.push_back(CreateMemcpyAsyncOp(graph, input)); |
|
|
|
} else { |
|
|
|
|