|
|
|
@@ -422,11 +422,14 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> |
|
|
|
MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name; |
|
|
|
reg_exist = false; |
|
|
|
} |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { |
|
|
|
reg_exist = false; |
|
|
|
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) { |
|
|
|
reg_exist = false; |
|
|
|
} |
|
|
|
} |
|
|
|
if (op_run_info->op_name == prim::kPrimGatherD->name()) { |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
// Gather op needs converting const input to attr on GPU device |
|
|
|
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) { |
|
|
|
reg_exist = false; |
|
|
|
|