From 2aad57c595320da8a548105f411dc08e8538bb9d Mon Sep 17 00:00:00 2001 From: jojobugfree Date: Tue, 14 Apr 2020 14:57:23 +0800 Subject: [PATCH] getnext disable memory reuse --- mindspore/ccsrc/device/kernel_runtime.cc | 4 ++++ mindspore/ccsrc/session/anf_runtime_algorithm.cc | 5 +++++ mindspore/ccsrc/session/anf_runtime_algorithm.h | 2 ++ mindspore/ccsrc/utils/utils.h | 1 + 4 files changed, 12 insertions(+) diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index eebc650347..db79484f8c 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -355,6 +355,10 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in AssignCommunicationNodeOutputMem(flag, node); return; } + if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { + MS_LOG(INFO) << "GetNext disable mem_reuse"; + flag = kDynamicMem; + } auto kernel_mod = AnfAlgo::GetKernelMod(node); MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 893c379a07..29a27a65b1 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -825,5 +825,10 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { } return false; } + +bool AnfRuntimeAlgorithm::IsGetNext(const NotNull &node) { + auto kernel_name = AnfAlgo::GetCNodeName(node); + return kernel_name == kGetNextOpName; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 1a1d471b84..ab5a68db7f 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -31,6 +31,7 @@ #include "kernel/kernel.h" #include "kernel/kernel_build_info.h" #include "operator/ops.h" +#include "utils/contract.h" namespace mindspore { namespace session { @@ -169,6 +170,7 @@ class AnfRuntimeAlgorithm { // get real input index for some tbe ops which input order is different between me and tbe impl static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static bool IsCommunicationOp(const AnfNodePtr &node); + static bool IsGetNext(const NotNull &node); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 39b4b7a160..e1df2a8d25 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -42,6 +42,7 @@ constexpr auto kBNGrad2OpName = "BNGrad2"; constexpr auto kBNGrad3OpName = "BNGrad3"; constexpr auto kClearZeroOpName = "ClearZero"; constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean"; +constexpr auto kGetNextOpName = "GetNext"; constexpr auto kAllReduceOpName = "AllReduce"; constexpr auto kAllGatherOpName = "AllGather"; constexpr auto kBroadcastOpName = "Broadcast";