|
|
|
@@ -1044,8 +1044,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, |
|
|
|
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); |
|
|
|
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); |
|
|
|
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true); |
|
|
|
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true); |
|
|
|
MS_EXCEPTION_IF_NULL(input_param.first); |
|
|
|
MS_EXCEPTION_IF_NULL(input_index.first); |
|
|
|
auto param_name = input_param.first->fullname_with_scope(); |
|
|
|
@@ -1053,11 +1053,11 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); |
|
|
|
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { |
|
|
|
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); |
|
|
|
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) { |
|
|
|
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true); |
|
|
|
MS_EXCEPTION_IF_NULL(input_index.first); |
|
|
|
} |
|
|
|
if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { |
|
|
|
if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { |
|
|
|
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch(); |
|
|
|
if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) { |
|
|
|
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() |
|
|
|
@@ -1085,13 +1085,13 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g |
|
|
|
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); |
|
|
|
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); |
|
|
|
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true); |
|
|
|
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true); |
|
|
|
MS_EXCEPTION_IF_NULL(input_param.first); |
|
|
|
MS_EXCEPTION_IF_NULL(input_index.first); |
|
|
|
auto param_name = input_param.first->fullname_with_scope(); |
|
|
|
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { |
|
|
|
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); |
|
|
|
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) { |
|
|
|
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true); |
|
|
|
MS_EXCEPTION_IF_NULL(input_index.first); |
|
|
|
} |
|
|
|
if (input_index.first == first_cache_input_index) { |
|
|
|
|