| @@ -355,6 +355,10 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in | |||||
| AssignCommunicationNodeOutputMem(flag, node); | AssignCommunicationNodeOutputMem(flag, node); | ||||
| return; | return; | ||||
| } | } | ||||
| if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { | |||||
| MS_LOG(INFO) << "GetNext disable mem_reuse"; | |||||
| flag = kDynamicMem; | |||||
| } | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(node); | auto kernel_mod = AnfAlgo::GetKernelMod(node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | MS_EXCEPTION_IF_NULL(kernel_mod); | ||||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | auto output_sizes = kernel_mod->GetOutputSizeList(); | ||||
| @@ -825,5 +825,10 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) { | |||||
| auto kernel_name = AnfAlgo::GetCNodeName(node); | |||||
| return kernel_name == kGetNextOpName; | |||||
| } | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -31,6 +31,7 @@ | |||||
| #include "kernel/kernel.h" | #include "kernel/kernel.h" | ||||
| #include "kernel/kernel_build_info.h" | #include "kernel/kernel_build_info.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "utils/contract.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -169,6 +170,7 @@ class AnfRuntimeAlgorithm { | |||||
| // get real input index for some tbe ops which input order is different between me and tbe impl | // get real input index for some tbe ops which input order is different between me and tbe impl | ||||
| static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); | static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); | ||||
| static bool IsCommunicationOp(const AnfNodePtr &node); | static bool IsCommunicationOp(const AnfNodePtr &node); | ||||
| static bool IsGetNext(const NotNull<AnfNodePtr> &node); | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | using AnfAlgo = session::AnfRuntimeAlgorithm; | ||||
| @@ -42,6 +42,7 @@ constexpr auto kBNGrad2OpName = "BNGrad2"; | |||||
| constexpr auto kBNGrad3OpName = "BNGrad3"; | constexpr auto kBNGrad3OpName = "BNGrad3"; | ||||
| constexpr auto kClearZeroOpName = "ClearZero"; | constexpr auto kClearZeroOpName = "ClearZero"; | ||||
| constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean"; | constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean"; | ||||
| constexpr auto kGetNextOpName = "GetNext"; | |||||
| constexpr auto kAllReduceOpName = "AllReduce"; | constexpr auto kAllReduceOpName = "AllReduce"; | ||||
| constexpr auto kAllGatherOpName = "AllGather"; | constexpr auto kAllGatherOpName = "AllGather"; | ||||
| constexpr auto kBroadcastOpName = "Broadcast"; | constexpr auto kBroadcastOpName = "Broadcast"; | ||||