diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index d9f476f0ca..390f9cdf14 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -422,11 +422,14 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name; reg_exist = false; } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { - reg_exist = false; + if (ms_context->get_param(MS_CTX_DEVICE_TARGET) != kCPUDevice) { + reg_exist = false; + } } if (op_run_info->op_name == prim::kPrimGatherD->name()) { - auto ms_context = MsContext::GetInstance(); // Gather op needs converting const input to attr on GPU device if (ms_context->get_param(MS_CTX_DEVICE_TARGET) != kGPUDevice) { reg_exist = false;