|
|
|
@@ -36,19 +36,20 @@ void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr |
|
|
|
map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creater); |
|
|
|
} |
|
|
|
|
|
|
|
void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, |
|
|
|
bool GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, |
|
|
|
std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second, |
|
|
|
size_t attr_index) { |
|
|
|
if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { |
|
|
|
if (!iter_second->at(attr_index).first.GetAllSame()) { |
|
|
|
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { |
|
|
|
if (!iter_second->at(attr_index).first.GetAllSame()) { |
|
|
|
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) { |
|
|
|
@@ -119,7 +120,9 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string & |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { |
|
|
|
CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index); |
|
|
|
if (!CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
bool flag = true; |
|
|
|
auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize(); |
|
|
|
// data type matching check of all input parameters of kernel |
|
|
|
|