From: @jpc_chenjianping Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tongpull/16020/MERGE
| @@ -134,6 +134,8 @@ set(LITE_SRC | |||||
| ${LITE_DIR}/src/common/prim_util.cc | ${LITE_DIR}/src/common/prim_util.cc | ||||
| ${LITE_DIR}/src/common/tensor_util.cc | ${LITE_DIR}/src/common/tensor_util.cc | ||||
| ${LITE_DIR}/src/runtime/infer_manager.cc | ${LITE_DIR}/src/runtime/infer_manager.cc | ||||
| ${LITE_DIR}/src/kernel_interface_registry.cc | |||||
| ${LITE_DIR}/src/kernel_registry.cc | |||||
| ${LITE_DIR}/src/lite_model.cc | ${LITE_DIR}/src/lite_model.cc | ||||
| ${LITE_DIR}/src/tensorlist.cc | ${LITE_DIR}/src/tensorlist.cc | ||||
| ${LITE_DIR}/src/tensor.cc | ${LITE_DIR}/src/tensor.cc | ||||
| @@ -23,8 +23,13 @@ RegisterKernelInterface *RegisterKernelInterface::Instance() { | |||||
| return &instance; | return &instance; | ||||
| } | } | ||||
| int RegisterKernelInterface::Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) { | |||||
| return lite::KernelInterfaceRegistry::Instance()->Reg(vendor, op_type, creator); | |||||
| int RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { | |||||
| return lite::KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator); | |||||
| } | |||||
| int RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type, | |||||
| KernelInterfaceCreator creator) { | |||||
| return lite::KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator); | |||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,7 +32,7 @@ struct CapabilityParam { | |||||
| class KernelInterface { | class KernelInterface { | ||||
| public: | public: | ||||
| virtual ~KernelInterface() = default; | virtual ~KernelInterface() = default; | ||||
| virtual int Infer(const std::vector<tensor::MSTensor *> &tensor_in, std::vector<tensor::MSTensor *> *outputs, | |||||
| virtual int Infer(const std::vector<tensor::MSTensor *> &inputs, const std::vector<tensor::MSTensor *> &outputs, | |||||
| const schema::Primitive *primitive) { | const schema::Primitive *primitive) { | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -47,7 +47,8 @@ typedef KernelInterface *(*KernelInterfaceCreator)(); | |||||
| class RegisterKernelInterface { | class RegisterKernelInterface { | ||||
| public: | public: | ||||
| static RegisterKernelInterface *Instance(); | static RegisterKernelInterface *Instance(); | ||||
| int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator); | |||||
| int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator); | |||||
| int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator); | |||||
| virtual ~RegisterKernelInterface() = default; | virtual ~RegisterKernelInterface() = default; | ||||
| private: | private: | ||||
| @@ -56,14 +57,21 @@ class RegisterKernelInterface { | |||||
| class KernelInterfaceReg { | class KernelInterfaceReg { | ||||
| public: | public: | ||||
| KernelInterfaceReg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) { | |||||
| RegisterKernelInterface::Instance()->Reg(vendor, op_type, creator); | |||||
| KernelInterfaceReg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { | |||||
| RegisterKernelInterface::Instance()->Reg(provider, op_type, creator); | |||||
| } | |||||
| KernelInterfaceReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator) { | |||||
| RegisterKernelInterface::Instance()->CustomReg(provider, op_type, creator); | |||||
| } | } | ||||
| ~KernelInterfaceReg() = default; | ~KernelInterfaceReg() = default; | ||||
| }; | }; | ||||
| #define REGISTER_KERNEL_INTERFACE(vendor, op_type, creator) \ | |||||
| static KernelInterfaceReg g_##vendor##op_type##_inter_reg(vendor, op_type, creator); | |||||
| #define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \ | |||||
| static KernelInterfaceReg g_##provider##op_type##_inter_reg(provider, op_type, creator); | |||||
| #define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \ | |||||
| static KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(provider, op_type, creator); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "src/kernel_interface.h" | #include "src/kernel_interface.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "src/common/version_manager.h" | |||||
| #include "schema/model_generated.h" | |||||
| using mindspore::kernel::KernelInterfaceCreator; | using mindspore::kernel::KernelInterfaceCreator; | ||||
| using mindspore::schema::PrimitiveType_MAX; | using mindspore::schema::PrimitiveType_MAX; | ||||
| @@ -24,27 +26,89 @@ using mindspore::schema::PrimitiveType_MIN; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | namespace { | ||||
| static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1; | |||||
| static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN; | |||||
| } | } | ||||
| int KernelInterfaceRegistry::Reg(const std::string &vendor, const int &op_type, KernelInterfaceCreator creator) { | |||||
| auto vendor_hash = std::hash<std::string>{}(vendor); | |||||
| auto iter = kernel_interfaces_.find(vendor_hash); | |||||
| bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) { | |||||
| if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) { | |||||
| return false; | |||||
| } | |||||
| auto primitive = static_cast<const schema::Primitive *>(node->primitive_); | |||||
| if (primitive == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto op_type = primitive->value_type(); | |||||
| if (op_type == schema::PrimitiveType_Custom) { | |||||
| return std::any_of(custom_interfaces_.begin(), custom_interfaces_.end(), [node](auto &&item) { | |||||
| if (item.second[node->name_] != nullptr) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| }); | |||||
| } | |||||
| return std::any_of(kernel_interfaces_.begin(), kernel_interfaces_.end(), | |||||
| [op_type, &mutex = this->mutex_](auto &&item) { | |||||
| std::unique_lock<std::mutex> lock(mutex); | |||||
| if (item.second[op_type] != nullptr) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| }); | |||||
| } | |||||
| int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &op_type, | |||||
| KernelInterfaceCreator creator) { | |||||
| custom_interfaces_[provider][op_type] = creator; | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::KernelInterface *KernelInterfaceRegistry::GetKernelInterface(const std::string &provider, int op_type) { | |||||
| if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) { | |||||
| MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_lock<std::mutex> lock(mutex_); | |||||
| auto iter = kernel_interfaces_.find(provider); | |||||
| if (iter == kernel_interfaces_.end()) { | if (iter == kernel_interfaces_.end()) { | ||||
| kernel_interfaces_[vendor_hash] = | |||||
| reinterpret_cast<KernelInterfaceCreator *>(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator))); | |||||
| if (kernel_interfaces_[vendor_hash] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| auto creator = iter->second[op_type]; | |||||
| if (creator != nullptr) { | |||||
| return creator(); | |||||
| } | } | ||||
| return nullptr; | |||||
| } | |||||
| int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { | |||||
| if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) { | if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) { | ||||
| MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum; | MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| kernel_interfaces_[vendor_hash][op_type] = creator; | |||||
| std::unique_lock<std::mutex> lock(mutex_); | |||||
| auto iter = kernel_interfaces_.find(provider); | |||||
| if (iter == kernel_interfaces_.end()) { | |||||
| kernel_interfaces_[provider] = | |||||
| reinterpret_cast<KernelInterfaceCreator *>(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator))); | |||||
| if (kernel_interfaces_[provider] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| kernel_interfaces_[provider][op_type] = creator; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| KernelInterfaceRegistry::~KernelInterfaceRegistry() { | |||||
| for (auto &&item : kernel_interfaces_) { | |||||
| free(item.second); | |||||
| item.second = nullptr; | |||||
| } | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,8 +18,10 @@ | |||||
| #define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_ | #define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_ | ||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | |||||
| #include <map> | |||||
| #include <mutex> | |||||
| #include "src/kernel_interface.h" | #include "src/kernel_interface.h" | ||||
| #include "include/model.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -29,14 +31,21 @@ class KernelInterfaceRegistry { | |||||
| static KernelInterfaceRegistry instance; | static KernelInterfaceRegistry instance; | ||||
| return &instance; | return &instance; | ||||
| } | } | ||||
| int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator); | |||||
| virtual ~KernelInterfaceRegistry() = default; | |||||
| bool CheckReg(const lite::Model::Node *node); | |||||
| kernel::KernelInterface *GetKernelInterface(const std::string &provider, int op_type); | |||||
| const std::map<std::string, kernel::KernelInterfaceCreator *> &kernel_interfaces() { return kernel_interfaces_; } | |||||
| int CustomReg(const std::string &provider, const std::string &op_type, kernel::KernelInterfaceCreator creator); | |||||
| int Reg(const std::string &provider, int op_type, kernel::KernelInterfaceCreator creator); | |||||
| virtual ~KernelInterfaceRegistry(); | |||||
| private: | private: | ||||
| KernelInterfaceRegistry() = default; | KernelInterfaceRegistry() = default; | ||||
| std::unordered_map<size_t, kernel::KernelInterfaceCreator *> kernel_interfaces_; | |||||
| std::mutex mutex_; | |||||
| // key: provider | |||||
| std::map<std::string, kernel::KernelInterfaceCreator *> kernel_interfaces_; | |||||
| // key: provider key: custom type | |||||
| std::map<std::string, std::map<std::string, kernel::KernelInterfaceCreator>> custom_interfaces_; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -38,7 +38,7 @@ using mindspore::kernel::KernelKey; | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| namespace { | namespace { | ||||
| static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin + 1) * (PrimitiveType_MAX - PrimitiveType_MIN + 1); | |||||
| static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin - 1) * (PrimitiveType_MAX - PrimitiveType_MIN); | |||||
| } // namespace | } // namespace | ||||
| KernelRegistry *KernelRegistry::GetInstance() { | KernelRegistry *KernelRegistry::GetInstance() { | ||||
| @@ -56,50 +56,81 @@ KernelRegistry *KernelRegistry::GetInstance() { | |||||
| } | } | ||||
| int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) { | int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) { | ||||
| int dType_index = static_cast<int>(desc.data_type) - kNumberTypeBegin; | |||||
| return dType_index * op_type_length_ + desc.type; | |||||
| if (desc.data_type >= kNumberTypeEnd) { | |||||
| return -1; | |||||
| } | |||||
| int data_type_index = static_cast<int>(desc.data_type) - kNumberTypeBegin - 1; | |||||
| if (data_type_index < 0) { | |||||
| return -1; | |||||
| } | |||||
| return data_type_index * op_type_length_ + desc.type; | |||||
| } | } | ||||
| int KernelRegistry::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, | |||||
| const int type, kernel::CreateKernel creator) { | |||||
| auto vendor_hash = std::hash<std::string>{}(vendor); | |||||
| auto arch_hash = std::hash<std::string>{}(arch); | |||||
| auto iter = kernel_creators_.find(vendor_hash); | |||||
| int KernelRegistry::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, | |||||
| const std::string &type, CreateKernel creator) { | |||||
| if (data_type >= kNumberTypeEnd) { | |||||
| MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::unique_lock<std::mutex> lock(lock_); | |||||
| auto iter = custom_kernel_creators_.find(provider); | |||||
| if (iter == custom_kernel_creators_.end()) { | |||||
| custom_kernel_creators_[provider][arch] = | |||||
| reinterpret_cast<CreateKernel *>(malloc(data_type_length_ * sizeof(CreateKernel))); | |||||
| if (custom_kernel_creators_[provider][arch] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(custom_kernel_creators_[provider][arch], 0, data_type_length_ * sizeof(CreateKernel)); | |||||
| } | |||||
| int data_type_index = data_type - kNumberTypeBegin - 1; | |||||
| if (data_type_index < 0) { | |||||
| MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider; | |||||
| return RET_ERROR; | |||||
| } | |||||
| custom_kernel_creators_[provider][arch][data_type_index] = creator; | |||||
| return RET_OK; | |||||
| } | |||||
| int KernelRegistry::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type, | |||||
| kernel::CreateKernel creator) { | |||||
| std::unique_lock<std::mutex> lock(lock_); | |||||
| auto iter = kernel_creators_.find(provider); | |||||
| if (iter == kernel_creators_.end()) { | if (iter == kernel_creators_.end()) { | ||||
| all_vendors_.insert(vendor); | |||||
| kernel_creators_[vendor_hash][arch_hash] = | |||||
| reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel))); | |||||
| if (kernel_creators_[vendor_hash][arch_hash] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch; | |||||
| kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel))); | |||||
| if (kernel_creators_[provider][arch] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| memset(kernel_creators_[vendor_hash][arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel)); | |||||
| memset(kernel_creators_[provider][arch], 0, kKernelMaxNum * sizeof(CreateKernel)); | |||||
| } else { | } else { | ||||
| auto iter_arch = iter->second.find(arch_hash); | |||||
| auto iter_arch = iter->second.find(arch); | |||||
| if (iter_arch == iter->second.end()) { | if (iter_arch == iter->second.end()) { | ||||
| iter->second[arch_hash] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel))); | |||||
| if (iter->second[arch_hash] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch; | |||||
| iter->second[arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel))); | |||||
| if (iter->second[arch] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| memset(iter->second[arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel)); | |||||
| memset(iter->second[arch], 0, kKernelMaxNum * sizeof(CreateKernel)); | |||||
| } | } | ||||
| } | } | ||||
| KernelKey desc = {kCPU, data_type, type, arch, vendor}; | |||||
| KernelKey desc = {kCPU, data_type, type, arch, provider}; | |||||
| int index = GetFuncIndex(desc); | int index = GetFuncIndex(desc); | ||||
| if (index >= kKernelMaxNum) { | |||||
| if (index >= kKernelMaxNum || index < 0) { | |||||
| MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type; | MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| kernel_creators_[vendor_hash][arch_hash][index] = creator; | |||||
| kernel_creators_[provider][arch][index] = creator; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int KernelRegistry::Init() { return RET_OK; } | int KernelRegistry::Init() { return RET_OK; } | ||||
| kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { | kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { | ||||
| if (desc.vendor == kBuiltin) { | |||||
| if (desc.provider == kBuiltin) { | |||||
| int index = GetCreatorFuncIndex(desc); | int index = GetCreatorFuncIndex(desc); | ||||
| if (index >= array_size_ || index < 0) { | if (index >= array_size_ || index < 0) { | ||||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " | MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " | ||||
| @@ -108,29 +139,29 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { | |||||
| } | } | ||||
| return creator_arrays_[index]; | return creator_arrays_[index]; | ||||
| } | } | ||||
| MS_LOG(ERROR) << "Call wrong interface!vendor: " << desc.vendor; | |||||
| MS_LOG(ERROR) << "Call wrong interface!provider: " << desc.provider; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| kernel::CreateKernel KernelRegistry::GetDelegateCreator(const kernel::KernelKey &desc) { | kernel::CreateKernel KernelRegistry::GetDelegateCreator(const kernel::KernelKey &desc) { | ||||
| auto vendor_hash = std::hash<std::string>{}(desc.vendor); | |||||
| auto it_by_vendor = kernel_creators_.find(vendor_hash); | |||||
| if (it_by_vendor == kernel_creators_.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto arch_hash = std::hash<std::string>{}(desc.kernel_arch); | |||||
| auto it_by_arch = it_by_vendor->second.find(arch_hash); | |||||
| if (it_by_arch == it_by_vendor->second.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| kernel::CreateKernel creator = nullptr; | |||||
| auto index = GetFuncIndex(desc); | auto index = GetFuncIndex(desc); | ||||
| if (index < 0 || index >= kKernelMaxNum) { | |||||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.kernel_arch << ", data_type" << desc.data_type << ",op type " | |||||
| << desc.type << ", vendor: " << desc.vendor; | |||||
| if (index >= kKernelMaxNum || index < 0) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return it_by_arch->second[index]; | |||||
| std::unique_lock<std::mutex> lock(lock_); | |||||
| for (auto &&item : kernel_creators_) { | |||||
| for (auto &&arch_item : item.second) { | |||||
| creator = arch_item.second[index]; | |||||
| if (creator != nullptr) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (creator != nullptr) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| return creator; | |||||
| } | } | ||||
| int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { | int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { | ||||
| @@ -172,6 +203,19 @@ KernelRegistry::~KernelRegistry() { | |||||
| free(instance->creator_arrays_); | free(instance->creator_arrays_); | ||||
| instance->creator_arrays_ = nullptr; | instance->creator_arrays_ = nullptr; | ||||
| } | } | ||||
| for (auto &&item : kernel_creators_) { | |||||
| for (auto &&creator : item.second) { | |||||
| free(creator.second); | |||||
| creator.second = nullptr; | |||||
| } | |||||
| } | |||||
| for (auto &&item : custom_kernel_creators_) { | |||||
| for (auto &&creator : item.second) { | |||||
| free(creator.second); | |||||
| creator.second = nullptr; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| bool KernelRegistry::SupportKernel(const KernelKey &key) { | bool KernelRegistry::SupportKernel(const KernelKey &key) { | ||||
| @@ -184,7 +228,7 @@ int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std | |||||
| kernel::LiteKernel **kernel, const void *primitive) { | kernel::LiteKernel **kernel, const void *primitive) { | ||||
| MS_ASSERT(ctx != nullptr); | MS_ASSERT(ctx != nullptr); | ||||
| MS_ASSERT(kernel != nullptr); | MS_ASSERT(kernel != nullptr); | ||||
| if (key.vendor == kBuiltin) { | |||||
| if (key.provider == kBuiltin) { | |||||
| auto creator = GetCreator(key); | auto creator = GetCreator(key); | ||||
| if (creator != nullptr) { | if (creator != nullptr) { | ||||
| *kernel = creator(in_tensors, out_tensors, parameter, ctx, key); | *kernel = creator(in_tensors, out_tensors, parameter, ctx, key); | ||||
| @@ -196,17 +240,18 @@ int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std | |||||
| } | } | ||||
| } else { | } else { | ||||
| auto creator = GetDelegateCreator(key); | auto creator = GetDelegateCreator(key); | ||||
| if (creator != nullptr) { | |||||
| std::vector<tensor::MSTensor *> tensors_in; | |||||
| Tensor2MSTensor(std::move(in_tensors), &tensors_in); | |||||
| std::vector<tensor::MSTensor *> tensors_out; | |||||
| Tensor2MSTensor(std::move(out_tensors), &tensors_out); | |||||
| *kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx); | |||||
| if (*kernel != nullptr) { | |||||
| return RET_OK; | |||||
| } | |||||
| return RET_ERROR; | |||||
| if (creator == nullptr) { | |||||
| return RET_NOT_SUPPORT; | |||||
| } | } | ||||
| std::vector<tensor::MSTensor *> tensors_in; | |||||
| Tensor2MSTensor(std::move(in_tensors), &tensors_in); | |||||
| std::vector<tensor::MSTensor *> tensors_out; | |||||
| Tensor2MSTensor(std::move(out_tensors), &tensors_out); | |||||
| *kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx); | |||||
| if (*kernel != nullptr) { | |||||
| return RET_OK; | |||||
| } | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| return RET_NOT_SUPPORT; | return RET_NOT_SUPPORT; | ||||
| } | } | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ | #define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ | ||||
| #include <string> | #include <string> | ||||
| #include <map> | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| @@ -42,9 +43,14 @@ class KernelRegistry { | |||||
| virtual kernel::CreateKernel GetDelegateCreator(const kernel::KernelKey &desc); | virtual kernel::CreateKernel GetDelegateCreator(const kernel::KernelKey &desc); | ||||
| int GetCreatorFuncIndex(kernel::KernelKey desc); | int GetCreatorFuncIndex(kernel::KernelKey desc); | ||||
| int GetFuncIndex(const kernel::KernelKey &desc); | int GetFuncIndex(const kernel::KernelKey &desc); | ||||
| const std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> &kernel_creators() { | |||||
| return kernel_creators_; | |||||
| } | |||||
| int RegCustomKernel(const std::string &arch, const std::string &vendor, TypeId data_type, const std::string &type, | |||||
| kernel::CreateKernel creator); | |||||
| void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator); | void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator); | ||||
| void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator); | void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator); | ||||
| int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type, | |||||
| int RegKernel(const std::string &arch, const std::string &vendor, TypeId data_type, int type, | |||||
| kernel::CreateKernel creator); | kernel::CreateKernel creator); | ||||
| bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators); | bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators); | ||||
| bool SupportKernel(const kernel::KernelKey &key); | bool SupportKernel(const kernel::KernelKey &key); | ||||
| @@ -58,8 +64,8 @@ class KernelRegistry { | |||||
| static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1}; | static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1}; | ||||
| static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_}; | static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_}; | ||||
| kernel::KernelCreator *creator_arrays_ = nullptr; | kernel::KernelCreator *creator_arrays_ = nullptr; | ||||
| std::unordered_map<std::size_t, std::unordered_map<std::size_t, kernel::CreateKernel *>> kernel_creators_; | |||||
| std::set<std::string> all_vendors_; | |||||
| std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> kernel_creators_; | |||||
| std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> custom_kernel_creators_; | |||||
| private: | private: | ||||
| std::mutex lock_; | std::mutex lock_; | ||||
| @@ -41,11 +41,11 @@ struct KernelKey { | |||||
| TypeId data_type; | TypeId data_type; | ||||
| int type; | int type; | ||||
| std::string kernel_arch; | std::string kernel_arch; | ||||
| std::string vendor{kBuiltin}; | |||||
| std::string provider{kBuiltin}; | |||||
| bool operator<(const KernelKey &dst) const { | bool operator<(const KernelKey &dst) const { | ||||
| if (vendor != dst.vendor) { | |||||
| return vendor < dst.vendor; | |||||
| if (provider != dst.provider) { | |||||
| return provider < dst.provider; | |||||
| } else if (kernel_arch != dst.kernel_arch) { | } else if (kernel_arch != dst.kernel_arch) { | ||||
| return kernel_arch < dst.kernel_arch; | return kernel_arch < dst.kernel_arch; | ||||
| } else if (arch != dst.arch) { | } else if (arch != dst.arch) { | ||||
| @@ -23,9 +23,14 @@ RegisterKernel *RegisterKernel::GetInstance() { | |||||
| return &instance; | return &instance; | ||||
| } | } | ||||
| int RegisterKernel::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, | |||||
| const int op_type, CreateKernel creator) { | |||||
| return lite::KernelRegistry::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator); | |||||
| int RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, | |||||
| const std::string &type, CreateKernel creator) { | |||||
| return lite::KernelRegistry::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator); | |||||
| } | |||||
| int RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int op_type, | |||||
| CreateKernel creator) { | |||||
| return lite::KernelRegistry::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator); | |||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,22 +29,30 @@ typedef kernel::LiteKernel *(*CreateKernel)(const std::vector<tensor::MSTensor * | |||||
| class RegisterKernel { | class RegisterKernel { | ||||
| public: | public: | ||||
| static RegisterKernel *GetInstance(); | static RegisterKernel *GetInstance(); | ||||
| int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type, | |||||
| CreateKernel creator); | |||||
| int RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type, CreateKernel creator); | |||||
| int RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, const std::string &type, | |||||
| CreateKernel creator); | |||||
| }; | }; | ||||
| class KernelReg { | class KernelReg { | ||||
| public: | public: | ||||
| ~KernelReg() = default; | ~KernelReg() = default; | ||||
| KernelReg(const std::string &arch, const std::string &vendor, const TypeId data_type, const int op_type, | |||||
| KernelReg(const std::string &arch, const std::string &provider, TypeId data_type, int op_type, CreateKernel creator) { | |||||
| RegisterKernel::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator); | |||||
| } | |||||
| KernelReg(const std::string &arch, const std::string &provider, TypeId data_type, const std::string &op_type, | |||||
| CreateKernel creator) { | CreateKernel creator) { | ||||
| RegisterKernel::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator); | |||||
| RegisterKernel::GetInstance()->RegCustomKernel(arch, provider, data_type, op_type, creator); | |||||
| } | } | ||||
| }; | }; | ||||
| #define REGISTER_KERNEL(arch, vendor, data_type, op_type, creator) \ | |||||
| static KernelReg g_##arch##vendor##data_type##op_type##kernelReg(arch, vendor, data_type, op_type, creator); | |||||
| #define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \ | |||||
| static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator); | |||||
| #define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \ | |||||
| static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,14 +14,43 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/infer_manager.h" | #include "src/runtime/infer_manager.h" | ||||
| #include <algorithm> | |||||
| #include "src/common/prim_util.h" | |||||
| #include "src/common/tensor_util.h" | #include "src/common/tensor_util.h" | ||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| #include "src/tensorlist.h" | #include "src/tensorlist.h" | ||||
| #include "src/kernel_interface_registry.h" | |||||
| #include "src/kernel_registry.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const void *primitive) { | |||||
| std::vector<tensor::MSTensor *> in_tensors; | |||||
| std::copy(inputs.begin(), inputs.end(), std::back_inserter(in_tensors)); | |||||
| std::vector<tensor::MSTensor *> out_tensors; | |||||
| std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors)); | |||||
| int op_type = GetPrimitiveType(primitive); | |||||
| for (auto &&item : KernelInterfaceRegistry::Instance()->kernel_interfaces()) { | |||||
| auto provider = item.first; | |||||
| auto kernel_interface = KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, op_type); | |||||
| if (kernel_interface == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto ret = kernel_interface->Infer(in_tensors, out_tensors, static_cast<const schema::Primitive *>(primitive)); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "provider: " << provider << ", op_type: " << PrimitiveTypeName(GetPrimitiveType(primitive)) | |||||
| << " infer fail!"; | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| return RET_ERROR; | |||||
| } | |||||
| int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs, | int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs, | ||||
| OpParameter *parameter) { | OpParameter *parameter) { | ||||
| std::vector<TensorC *> in_tensors; | std::vector<TensorC *> in_tensors; | ||||
| @@ -25,8 +25,10 @@ | |||||
| #include "nnacl/infer/infer.h" | #include "nnacl/infer/infer.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, std::vector<lite::Tensor *> *outputs, | |||||
| int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs, | |||||
| OpParameter *parameter); | OpParameter *parameter); | ||||
| int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const void *primitive); | |||||
| class InferManager { | class InferManager { | ||||
| public: | public: | ||||
| static InferManager *GetInstance() { | static InferManager *GetInstance() { | ||||
| @@ -48,6 +48,7 @@ | |||||
| #if defined(ENABLE_ARM) && defined(ENABLE_FP16) | #if defined(ENABLE_ARM) && defined(ENABLE_FP16) | ||||
| #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" | #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" | ||||
| #endif | #endif | ||||
| #include "src/kernel_interface_registry.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| using kernel::KERNEL_ARCH::kCPU; | using kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -130,6 +131,10 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) { | |||||
| std::vector<Tensor *> inputs; | std::vector<Tensor *> inputs; | ||||
| std::vector<Tensor *> outputs; | std::vector<Tensor *> outputs; | ||||
| FindNodeInoutTensors(*node, &inputs, &outputs); | FindNodeInoutTensors(*node, &inputs, &outputs); | ||||
| if (KernelInterfaceRegistry::Instance()->CheckReg(node)) { | |||||
| return KernelInferShape(inputs, outputs, node->primitive_); | |||||
| } | |||||
| int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | ||||
| auto parame_gen = | auto parame_gen = | ||||
| PopulateRegistry::GetInstance()->GetParameterCreator(GetPrimitiveType(node->primitive_), schema_version); | PopulateRegistry::GetInstance()->GetParameterCreator(GetPrimitiveType(node->primitive_), schema_version); | ||||
| @@ -432,6 +437,18 @@ int Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std: | |||||
| return RET_NOT_SUPPORT; | return RET_NOT_SUPPORT; | ||||
| } | } | ||||
| int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | |||||
| const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel) { | |||||
| int ret = RET_ERROR; | |||||
| if (KernelRegistry::GetInstance()->kernel_creators().size() != 0 && | |||||
| VersionManager::GetInstance()->GetSchemaVersion() != SCHEMA_V0) { | |||||
| kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), "", ""}; | |||||
| ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel, | |||||
| node->primitive_); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors, | kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors, | ||||
| const std::vector<Tensor *> &out_tensors, const Model::Node *node, | const std::vector<Tensor *> &out_tensors, const Model::Node *node, | ||||
| TypeId prefer_data_type) { | TypeId prefer_data_type) { | ||||
| @@ -439,14 +456,18 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||||
| // why we need this | // why we need this | ||||
| TypeId data_type = | TypeId data_type = | ||||
| (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors); | (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors); | ||||
| kernel::LiteKernel *kernel = nullptr; | |||||
| int status; | |||||
| status = FindProviderKernel(in_tensors, out_tensors, node, data_type, &kernel); | |||||
| if (status == RET_OK && kernel != nullptr) { | |||||
| return kernel; | |||||
| } | |||||
| OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)]; | OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)]; | ||||
| if (op_parameter == nullptr) { | if (op_parameter == nullptr) { | ||||
| MS_LOG(ERROR) << "Can not find OpParameter!type: " << PrimitiveTypeName(GetPrimitiveType(node->primitive_)); | MS_LOG(ERROR) << "Can not find OpParameter!type: " << PrimitiveTypeName(GetPrimitiveType(node->primitive_)); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)}; | kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)}; | ||||
| kernel::LiteKernel *kernel = nullptr; | |||||
| int status; | |||||
| #ifdef SUPPORT_GPU | #ifdef SUPPORT_GPU | ||||
| // if (node->device_type_ == DT_GPU || node->device_type_ == DEFAULT) { | // if (node->device_type_ == DT_GPU || node->device_type_ == DEFAULT) { | ||||
| status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel); | status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel); | ||||
| @@ -60,6 +60,7 @@ class Scheduler { | |||||
| kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors, | kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors, | ||||
| const std::vector<Tensor *> &out_tensors, const Model::Node *node, | const std::vector<Tensor *> &out_tensors, const Model::Node *node, | ||||
| TypeId prefer_data_type = kTypeUnknown); | TypeId prefer_data_type = kTypeUnknown); | ||||
| int FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | int FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | ||||
| OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type, | OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type, | ||||
| kernel::LiteKernel **kernel); | kernel::LiteKernel **kernel); | ||||
| @@ -67,6 +68,9 @@ class Scheduler { | |||||
| OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel); | OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel); | ||||
| int FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | int FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | ||||
| OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel); | OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel); | ||||
| int FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | |||||
| const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel); | |||||
| // schedule a partial node to a subgraph_kernel | // schedule a partial node to a subgraph_kernel | ||||
| kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node); | kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node); | ||||
| // schedule a node to a kernel | // schedule a node to a kernel | ||||
| @@ -141,6 +141,8 @@ set(TEST_LITE_SRC | |||||
| ${LITE_DIR}/src/tensorlist.cc | ${LITE_DIR}/src/tensorlist.cc | ||||
| ${LITE_DIR}/src/executor.cc | ${LITE_DIR}/src/executor.cc | ||||
| ${LITE_DIR}/src/inner_context.cc | ${LITE_DIR}/src/inner_context.cc | ||||
| ${LITE_DIR}/src/kernel_interface.cc | |||||
| ${LITE_DIR}/src/kernel_interface_registry.cc | |||||
| ${LITE_DIR}/src/kernel_registry.cc | ${LITE_DIR}/src/kernel_registry.cc | ||||
| ${LITE_DIR}/src/register_kernel.cc | ${LITE_DIR}/src/register_kernel.cc | ||||
| ${LITE_DIR}/src/lite_kernel.cc | ${LITE_DIR}/src/lite_kernel.cc | ||||
| @@ -123,6 +123,7 @@ set(LITE_SRC | |||||
| ${SRC_DIR}/tensor.cc | ${SRC_DIR}/tensor.cc | ||||
| ${SRC_DIR}/ms_tensor.cc | ${SRC_DIR}/ms_tensor.cc | ||||
| ${SRC_DIR}/tensorlist.cc | ${SRC_DIR}/tensorlist.cc | ||||
| ${SRC_DIR}/kernel_interface_registry.cc | |||||
| ${SRC_DIR}/kernel_registry.cc | ${SRC_DIR}/kernel_registry.cc | ||||
| ${SRC_DIR}/register_kernel.cc | ${SRC_DIR}/register_kernel.cc | ||||
| ${SRC_DIR}/lite_kernel.cc | ${SRC_DIR}/lite_kernel.cc | ||||