Browse Source

fix ps cache check cnode

tags/v1.1.0
limingqi107 5 years ago
parent
commit
560aa11b5f
3 changed files with 13 additions and 12 deletions
  1. +3
    -2
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  2. +1
    -1
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  3. +9
    -9
      mindspore/ccsrc/runtime/device/kernel_runtime.cc

+ 3
- 2
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -476,7 +476,8 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
return format;
}

KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx,
bool visit_nop_node) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
@@ -484,7 +485,7 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod
}
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
MS_EXCEPTION_IF_NULL(input_node);
return VisitKernelWithReturnType(input_node, 0);
return VisitKernelWithReturnType(input_node, 0, visit_nop_node);
}

std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {


+ 1
- 1
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h View File

@@ -112,7 +112,7 @@ class AnfRuntimeAlgorithm {
// get input format select of anf node
static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
// get prev node output width output index
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool visit_nop_node = false);
// get output format from prev node,input_index is the input index of current node related to prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get reshape_type of from the output of input node.


+ 9
- 9
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -1044,8 +1044,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
continue;
}
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
MS_EXCEPTION_IF_NULL(input_param.first);
MS_EXCEPTION_IF_NULL(input_index.first);
auto param_name = input_param.first->fullname_with_scope();
@@ -1053,11 +1053,11 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
continue;
}
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
@@ -1085,13 +1085,13 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
continue;
}
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
MS_EXCEPTION_IF_NULL(input_param.first);
MS_EXCEPTION_IF_NULL(input_index.first);
auto param_name = input_param.first->fullname_with_scope();
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
if (input_index.first == first_cache_input_index) {


Loading…
Cancel
Save