Browse Source

GPU kernel: support flexible inputs

tags/v1.1.0
wilfChen 5 years ago
parent
commit
907ca43330
2 changed files with 8 additions and 5 deletions
  1. +7
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h

+ 7
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc View File

@@ -36,19 +36,20 @@ void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr
map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creater);
}

void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
bool GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second,
size_t attr_index) {
if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) {
if (!iter_second->at(attr_index).first.GetAllSame()) {
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!";
return false;
}
}
if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) {
if (!iter_second->at(attr_index).first.GetAllSame()) {
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!";
return false;
}
}
return true;
}

std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) {
@@ -119,7 +120,9 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &
}

for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index);
if (!CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index)) {
continue;
}
bool flag = true;
auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize();
// data type matching check of all input parameters of kernel


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h View File

@@ -57,7 +57,7 @@ class GpuKernelFactory {
GpuKernelFactory &operator=(const GpuKernelFactory &);

std::pair<bool, size_t> GpuKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info);
void CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
bool CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second, size_t attr_index);
// Map from kernel name to its registered (KernelAttr, creater) pairs; a KernelAttr object and its creater must be registered as a pair.
std::map<std::string, std::vector<std::pair<KernelAttr, GpuKernelCreater>>> map_kernel_name_to_creater_;


Loading…
Cancel
Save