| @@ -422,11 +422,14 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> | |||||
| MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name; | MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name; | ||||
| reg_exist = false; | reg_exist = false; | ||||
| } | } | ||||
| auto ms_context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | |||||
| if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { | if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { | ||||
| reg_exist = false; | |||||
| if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) { | |||||
| reg_exist = false; | |||||
| } | |||||
| } | } | ||||
| if (op_run_info->op_name == prim::kPrimGatherD->name()) { | if (op_run_info->op_name == prim::kPrimGatherD->name()) { | ||||
| auto ms_context = MsContext::GetInstance(); | |||||
| // Gather op needs converting const input to attr on GPU device | // Gather op needs converting const input to attr on GPU device | ||||
| if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) { | if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) { | ||||
| reg_exist = false; | reg_exist = false; | ||||