Browse Source

getnext disable memory reuse

tags/v0.2.0-alpha
jojobugfree 5 years ago
parent
commit
2aad57c595
4 changed files with 12 additions and 0 deletions
  1. +4
    -0
      mindspore/ccsrc/device/kernel_runtime.cc
  2. +5
    -0
      mindspore/ccsrc/session/anf_runtime_algorithm.cc
  3. +2
    -0
      mindspore/ccsrc/session/anf_runtime_algorithm.h
  4. +1
    -0
      mindspore/ccsrc/utils/utils.h

+ 4
- 0
mindspore/ccsrc/device/kernel_runtime.cc View File

@@ -355,6 +355,10 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
AssignCommunicationNodeOutputMem(flag, node);
return;
}
if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
MS_LOG(INFO) << "GetNext disable mem_reuse";
flag = kDynamicMem;
}
auto kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList();


+ 5
- 0
mindspore/ccsrc/session/anf_runtime_algorithm.cc View File

@@ -825,5 +825,10 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
}
return false;
}

bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
auto kernel_name = AnfAlgo::GetCNodeName(node);
return kernel_name == kGetNextOpName;
}
} // namespace session
} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/session/anf_runtime_algorithm.h View File

@@ -31,6 +31,7 @@
#include "kernel/kernel.h"
#include "kernel/kernel_build_info.h"
#include "operator/ops.h"
#include "utils/contract.h"

namespace mindspore {
namespace session {
@@ -169,6 +170,7 @@ class AnfRuntimeAlgorithm {
// get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;


+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -42,6 +42,7 @@ constexpr auto kBNGrad2OpName = "BNGrad2";
constexpr auto kBNGrad3OpName = "BNGrad3";
constexpr auto kClearZeroOpName = "ClearZero";
constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
constexpr auto kGetNextOpName = "GetNext";
constexpr auto kAllReduceOpName = "AllReduce";
constexpr auto kAllGatherOpName = "AllGather";
constexpr auto kBroadcastOpName = "Broadcast";


Loading…
Cancel
Save