|
|
@@ -33,7 +33,21 @@ using mindspore::schema::PrimitiveType_MAX; |
|
|
using mindspore::schema::PrimitiveType_MIN; |
|
|
using mindspore::schema::PrimitiveType_MIN; |
|
|
|
|
|
|
|
|
namespace mindspore::lite { |
|
|
namespace mindspore::lite { |
|
|
KernelRegistry::KernelRegistry() {} |
|
|
|
|
|
|
|
|
KernelRegistry::KernelRegistry() { |
|
|
|
|
|
device_type_length_ = kKernelArch_MAX - kKernelArch_MIN + 1; |
|
|
|
|
|
data_type_length_ = kNumberTypeEnd - kNumberTypeBegin + 1; |
|
|
|
|
|
op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN + 1; |
|
|
|
|
|
// malloc an array contain creator functions of kernel |
|
|
|
|
|
auto total_len = device_type_length_ * data_type_length_ * op_type_length_; |
|
|
|
|
|
creator_arrays_ = (kernel::KernelCreator *)malloc(total_len * sizeof(kernel::KernelCreator)); |
|
|
|
|
|
if (creator_arrays_ == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "malloc creator_arrays_ failed."; |
|
|
|
|
|
} else { |
|
|
|
|
|
for (int i = 0; i < total_len; ++i) { |
|
|
|
|
|
creator_arrays_[i] = nullptr; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
KernelRegistry::~KernelRegistry() { FreeCreatorArray(); } |
|
|
KernelRegistry::~KernelRegistry() { FreeCreatorArray(); } |
|
|
|
|
|
|
|
|
@@ -43,25 +57,6 @@ KernelRegistry *KernelRegistry::GetInstance() { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
int KernelRegistry::Init() { |
|
|
int KernelRegistry::Init() { |
|
|
lock_.lock(); |
|
|
|
|
|
if (creator_arrays_ != nullptr) { |
|
|
|
|
|
lock_.unlock(); |
|
|
|
|
|
return RET_OK; |
|
|
|
|
|
} |
|
|
|
|
|
device_type_length_ = kKernelArch_MAX - kKernelArch_MIN; |
|
|
|
|
|
data_type_length_ = kNumberTypeEnd - kNumberTypeBegin; |
|
|
|
|
|
op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN; |
|
|
|
|
|
// malloc an array contain creator functions of kernel |
|
|
|
|
|
auto total_len = device_type_length_ * data_type_length_ * op_type_length_; |
|
|
|
|
|
creator_arrays_ = (kernel::KernelCreator *)malloc(total_len * sizeof(kernel::KernelCreator)); |
|
|
|
|
|
if (creator_arrays_ == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "malloc creator_arrays_ failed."; |
|
|
|
|
|
lock_.unlock(); |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
for (int i = 0; i < total_len; ++i) { |
|
|
|
|
|
creator_arrays_[i] = nullptr; |
|
|
|
|
|
} |
|
|
|
|
|
#ifdef ENABLE_ARM64 |
|
|
#ifdef ENABLE_ARM64 |
|
|
void *optimized_lib_handler = OptimizeModule::GetInstance()->optimized_op_handler_; |
|
|
void *optimized_lib_handler = OptimizeModule::GetInstance()->optimized_op_handler_; |
|
|
if (optimized_lib_handler != nullptr) { |
|
|
if (optimized_lib_handler != nullptr) { |
|
|
@@ -70,7 +65,6 @@ int KernelRegistry::Init() { |
|
|
MS_LOG(INFO) << "load optimize lib failed."; |
|
|
MS_LOG(INFO) << "load optimize lib failed."; |
|
|
} |
|
|
} |
|
|
#endif |
|
|
#endif |
|
|
lock_.unlock(); |
|
|
|
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -82,6 +76,10 @@ void KernelRegistry::FreeCreatorArray() { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { |
|
|
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { |
|
|
|
|
|
if (creator_arrays_ == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "Creator func array is null."; |
|
|
|
|
|
return nullptr; |
|
|
|
|
|
} |
|
|
int index = GetCreatorFuncIndex(desc); |
|
|
int index = GetCreatorFuncIndex(desc); |
|
|
auto it = creator_arrays_[index]; |
|
|
auto it = creator_arrays_[index]; |
|
|
if (it != nullptr) { |
|
|
if (it != nullptr) { |
|
|
@@ -100,12 +98,20 @@ int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) { |
|
|
void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) { |
|
|
|
|
|
if (creator_arrays_ == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "Creator func array is null."; |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
int index = GetCreatorFuncIndex(desc); |
|
|
int index = GetCreatorFuncIndex(desc); |
|
|
creator_arrays_[index] = creator; |
|
|
creator_arrays_[index] = creator; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type, |
|
|
void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type, |
|
|
kernel::KernelCreator creator) { |
|
|
kernel::KernelCreator creator) { |
|
|
|
|
|
if (creator_arrays_ == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "Creator func array is null."; |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
KernelKey desc = {arch, data_type, op_type}; |
|
|
KernelKey desc = {arch, data_type, op_type}; |
|
|
int index = GetCreatorFuncIndex(desc); |
|
|
int index = GetCreatorFuncIndex(desc); |
|
|
creator_arrays_[index] = creator; |
|
|
creator_arrays_[index] = creator; |
|
|
|