diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc index ee720c6e0a..4496bec023 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc @@ -36,19 +36,20 @@ void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creater); } -void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, +bool GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, std::vector> *iter_second, size_t attr_index) { if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { if (!iter_second->at(attr_index).first.GetAllSame()) { - MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; + return false; } } if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { if (!iter_second->at(attr_index).first.GetAllSame()) { - MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; + return false; } } + return true; } std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) { @@ -119,7 +120,9 @@ std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string & } for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { - CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index); + if (!CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index)) { + continue; + } bool flag = true; auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize(); // data type matching check of all input parameters of kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h index c4667c56c8..d49cad9920 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h @@ -57,7 +57,7 @@ class GpuKernelFactory { GpuKernelFactory &operator=(const GpuKernelFactory &); std::pair GpuKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info); - void CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, + bool CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, std::vector> *iter_second, size_t attr_index); // map to maintain kernel and creater, KernelAttr object and creater must be registered as a pair. std::map>> map_kernel_name_to_creater_;