From 27662712485c83b080f2783f549abe1aef111a6d Mon Sep 17 00:00:00 2001 From: chenjianping Date: Wed, 28 Apr 2021 21:46:19 +0800 Subject: [PATCH] support custom op infershape --- mindspore/lite/micro/cmake/file_list.cmake | 2 + mindspore/lite/src/kernel_interface.cc | 9 +- mindspore/lite/src/kernel_interface.h | 20 ++- .../lite/src/kernel_interface_registry.cc | 86 +++++++++-- .../lite/src/kernel_interface_registry.h | 19 ++- mindspore/lite/src/kernel_registry.cc | 145 ++++++++++++------ mindspore/lite/src/kernel_registry.h | 12 +- mindspore/lite/src/lite_kernel.h | 6 +- mindspore/lite/src/register_kernel.cc | 11 +- mindspore/lite/src/register_kernel.h | 20 ++- mindspore/lite/src/runtime/infer_manager.cc | 29 ++++ mindspore/lite/src/runtime/infer_manager.h | 4 +- mindspore/lite/src/scheduler.cc | 25 ++- mindspore/lite/src/scheduler.h | 4 + mindspore/lite/test/CMakeLists.txt | 2 + mindspore/lite/tools/converter/CMakeLists.txt | 1 + 16 files changed, 303 insertions(+), 92 deletions(-) diff --git a/mindspore/lite/micro/cmake/file_list.cmake b/mindspore/lite/micro/cmake/file_list.cmake index 7d3c51b3a5..934dfe6813 100644 --- a/mindspore/lite/micro/cmake/file_list.cmake +++ b/mindspore/lite/micro/cmake/file_list.cmake @@ -134,6 +134,8 @@ set(LITE_SRC ${LITE_DIR}/src/common/prim_util.cc ${LITE_DIR}/src/common/tensor_util.cc ${LITE_DIR}/src/runtime/infer_manager.cc + ${LITE_DIR}/src/kernel_interface_registry.cc + ${LITE_DIR}/src/kernel_registry.cc ${LITE_DIR}/src/lite_model.cc ${LITE_DIR}/src/tensorlist.cc ${LITE_DIR}/src/tensor.cc diff --git a/mindspore/lite/src/kernel_interface.cc b/mindspore/lite/src/kernel_interface.cc index 80f6e0f01c..9c92cc771b 100644 --- a/mindspore/lite/src/kernel_interface.cc +++ b/mindspore/lite/src/kernel_interface.cc @@ -23,8 +23,13 @@ RegisterKernelInterface *RegisterKernelInterface::Instance() { return &instance; } -int RegisterKernelInterface::Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) { - return lite::KernelInterfaceRegistry::Instance()->Reg(vendor, op_type, creator); +int RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { + return lite::KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator); +} + +int RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type, + KernelInterfaceCreator creator) { + return lite::KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/lite/src/kernel_interface.h b/mindspore/lite/src/kernel_interface.h index 011d008cd2..552f6b09a6 100644 --- a/mindspore/lite/src/kernel_interface.h +++ b/mindspore/lite/src/kernel_interface.h @@ -32,7 +32,7 @@ struct CapabilityParam { class KernelInterface { public: virtual ~KernelInterface() = default; - virtual int Infer(const std::vector &tensor_in, std::vector *outputs, + virtual int Infer(const std::vector &inputs, const std::vector &outputs, const schema::Primitive *primitive) { return 0; } @@ -47,7 +47,8 @@ typedef KernelInterface *(*KernelInterfaceCreator)(); class RegisterKernelInterface { public: static RegisterKernelInterface *Instance(); - int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator); + int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator); + int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator); virtual ~RegisterKernelInterface() = default; private: @@ -56,14 +57,21 @@ class RegisterKernelInterface { class KernelInterfaceReg { public: - KernelInterfaceReg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) { - RegisterKernelInterface::Instance()->Reg(vendor, op_type, creator); + KernelInterfaceReg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { + RegisterKernelInterface::Instance()->Reg(provider, op_type, creator); + } + + KernelInterfaceReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator) { + RegisterKernelInterface::Instance()->CustomReg(provider, op_type, creator); } ~KernelInterfaceReg() = default; }; -#define REGISTER_KERNEL_INTERFACE(vendor, op_type, creator) \ - static KernelInterfaceReg g_##vendor##op_type##_inter_reg(vendor, op_type, creator); +#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \ + static KernelInterfaceReg g_##provider##op_type##_inter_reg(provider, op_type, creator); + +#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \ + static KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(provider, op_type, creator); } // namespace kernel } // namespace mindspore diff --git a/mindspore/lite/src/kernel_interface_registry.cc b/mindspore/lite/src/kernel_interface_registry.cc index 4017841b7d..a14b53f5cd 100644 --- a/mindspore/lite/src/kernel_interface_registry.cc +++ b/mindspore/lite/src/kernel_interface_registry.cc @@ -17,6 +17,8 @@ #include "src/kernel_interface.h" #include "include/errorcode.h" #include "src/common/log_adapter.h" +#include "src/common/version_manager.h" +#include "schema/model_generated.h" using mindspore::kernel::KernelInterfaceCreator; using mindspore::schema::PrimitiveType_MAX; @@ -24,27 +26,89 @@ using mindspore::schema::PrimitiveType_MIN; namespace mindspore { namespace lite { namespace { -static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1; +static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN; } -int KernelInterfaceRegistry::Reg(const std::string &vendor, const int &op_type, KernelInterfaceCreator creator) { - auto vendor_hash = std::hash{}(vendor); - auto iter = kernel_interfaces_.find(vendor_hash); +bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) { + if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) { + return false; + } + auto primitive = static_cast(node->primitive_); + if (primitive == nullptr) { + return false; + } + + auto op_type = primitive->value_type(); + if (op_type == schema::PrimitiveType_Custom) { + return std::any_of(custom_interfaces_.begin(), custom_interfaces_.end(), [node](auto &&item) { + if (item.second[node->name_] != nullptr) { + return true; + } + return false; + }); + } + + return std::any_of(kernel_interfaces_.begin(), kernel_interfaces_.end(), + [op_type, &mutex = this->mutex_](auto &&item) { + std::unique_lock lock(mutex); + if (item.second[op_type] != nullptr) { + return true; + } + return false; + }); +} + +int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &op_type, + KernelInterfaceCreator creator) { + custom_interfaces_[provider][op_type] = creator; + return RET_OK; +} + +kernel::KernelInterface *KernelInterfaceRegistry::GetKernelInterface(const std::string &provider, int op_type) { + if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) { + MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum; + return nullptr; + } + + std::unique_lock lock(mutex_); + auto iter = kernel_interfaces_.find(provider); if (iter == kernel_interfaces_.end()) { - kernel_interfaces_[vendor_hash] = - reinterpret_cast(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator))); - if (kernel_interfaces_[vendor_hash] == nullptr) { - MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!"; - return RET_ERROR; - } + return nullptr; + } + + auto creator = iter->second[op_type]; + if (creator != nullptr) { + return creator(); } + return nullptr; +} + +int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) { MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum; return RET_ERROR; } - kernel_interfaces_[vendor_hash][op_type] = creator; + + std::unique_lock lock(mutex_); + auto iter = kernel_interfaces_.find(provider); + if (iter == kernel_interfaces_.end()) { + kernel_interfaces_[provider] = + reinterpret_cast(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator))); + if (kernel_interfaces_[provider] == nullptr) { + MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!"; + return RET_ERROR; + } + } + + kernel_interfaces_[provider][op_type] = creator; return RET_OK; } +KernelInterfaceRegistry::~KernelInterfaceRegistry() { + for (auto &&item : kernel_interfaces_) { + free(item.second); + item.second = nullptr; + } +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/kernel_interface_registry.h b/mindspore/lite/src/kernel_interface_registry.h index 33ded2b833..6387d575ef 100644 --- a/mindspore/lite/src/kernel_interface_registry.h +++ b/mindspore/lite/src/kernel_interface_registry.h @@ -18,8 +18,10 @@ #define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_ #include -#include +#include +#include #include "src/kernel_interface.h" +#include "include/model.h" namespace mindspore { namespace lite { @@ -29,14 +31,21 @@ class KernelInterfaceRegistry { static KernelInterfaceRegistry instance; return &instance; } - - int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator); - virtual ~KernelInterfaceRegistry() = default; + bool CheckReg(const lite::Model::Node *node); + kernel::KernelInterface *GetKernelInterface(const std::string &provider, int op_type); + const std::map &kernel_interfaces() { return kernel_interfaces_; } + int CustomReg(const std::string &provider, const std::string &op_type, kernel::KernelInterfaceCreator creator); + int Reg(const std::string &provider, int op_type, kernel::KernelInterfaceCreator creator); + virtual ~KernelInterfaceRegistry(); private: KernelInterfaceRegistry() = default; - std::unordered_map kernel_interfaces_; + std::mutex mutex_; + // key: provider + std::map kernel_interfaces_; + // key: provider key: custom type + std::map> custom_interfaces_; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index 508bcaefdf..70912ef6a0 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -38,7 +38,7 @@ using mindspore::kernel::KernelKey; namespace mindspore::lite { namespace { -static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin + 1) * (PrimitiveType_MAX - PrimitiveType_MIN + 1); +static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin - 1) * (PrimitiveType_MAX - PrimitiveType_MIN); } // namespace KernelRegistry *KernelRegistry::GetInstance() { @@ -56,50 +56,81 @@ KernelRegistry *KernelRegistry::GetInstance() { } int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) { - int dType_index = static_cast(desc.data_type) - kNumberTypeBegin; - return dType_index * op_type_length_ + desc.type; + if (desc.data_type >= kNumberTypeEnd) { + return -1; + } + int data_type_index = static_cast(desc.data_type) - kNumberTypeBegin - 1; + if (data_type_index < 0) { + return -1; + } + return data_type_index * op_type_length_ + desc.type; } -int KernelRegistry::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, - const int type, kernel::CreateKernel creator) { - auto vendor_hash = std::hash{}(vendor); - auto arch_hash = std::hash{}(arch); - auto iter = kernel_creators_.find(vendor_hash); +int KernelRegistry::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, + const std::string &type, CreateKernel creator) { + if (data_type >= kNumberTypeEnd) { + MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider; + return RET_ERROR; + } + std::unique_lock lock(lock_); + auto iter = custom_kernel_creators_.find(provider); + if (iter == custom_kernel_creators_.end()) { + custom_kernel_creators_[provider][arch] = + reinterpret_cast(malloc(data_type_length_ * sizeof(CreateKernel))); + if (custom_kernel_creators_[provider][arch] == nullptr) { + MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch; + return RET_ERROR; + } + memset(custom_kernel_creators_[provider][arch], 0, data_type_length_ * sizeof(CreateKernel)); + } + + int data_type_index = data_type - kNumberTypeBegin - 1; + if (data_type_index < 0) { + MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider; + return RET_ERROR; + } + custom_kernel_creators_[provider][arch][data_type_index] = creator; + return RET_OK; +} + +int KernelRegistry::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type, + kernel::CreateKernel creator) { + std::unique_lock lock(lock_); + auto iter = kernel_creators_.find(provider); if (iter == kernel_creators_.end()) { - all_vendors_.insert(vendor); - kernel_creators_[vendor_hash][arch_hash] = - reinterpret_cast(malloc(kKernelMaxNum * sizeof(CreateKernel))); - if (kernel_creators_[vendor_hash][arch_hash] == nullptr) { - MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch; + kernel_creators_[provider][arch] = reinterpret_cast(malloc(kKernelMaxNum * sizeof(CreateKernel))); + if (kernel_creators_[provider][arch] == nullptr) { + MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch; return RET_ERROR; } - memset(kernel_creators_[vendor_hash][arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel)); + memset(kernel_creators_[provider][arch], 0, kKernelMaxNum * sizeof(CreateKernel)); } else { - auto iter_arch = iter->second.find(arch_hash); + auto iter_arch = iter->second.find(arch); if (iter_arch == iter->second.end()) { - iter->second[arch_hash] = reinterpret_cast(malloc(kKernelMaxNum * sizeof(CreateKernel))); - if (iter->second[arch_hash] == nullptr) { - MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch; + iter->second[arch] = reinterpret_cast(malloc(kKernelMaxNum * sizeof(CreateKernel))); + if (iter->second[arch] == nullptr) { + MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch; return RET_ERROR; } - memset(iter->second[arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel)); + memset(iter->second[arch], 0, kKernelMaxNum * sizeof(CreateKernel)); } } - KernelKey desc = {kCPU, data_type, type, arch, vendor}; + KernelKey desc = {kCPU, data_type, type, arch, provider}; int index = GetFuncIndex(desc); - if (index >= kKernelMaxNum) { + if (index >= kKernelMaxNum || index < 0) { MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type; return RET_ERROR; } - kernel_creators_[vendor_hash][arch_hash][index] = creator; + + kernel_creators_[provider][arch][index] = creator; return RET_OK; } int KernelRegistry::Init() { return RET_OK; } kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { - if (desc.vendor == kBuiltin) { + if (desc.provider == kBuiltin) { int index = GetCreatorFuncIndex(desc); if (index >= array_size_ || index < 0) { MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " @@ -108,29 +139,29 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { } return creator_arrays_[index]; } - MS_LOG(ERROR) << "Call wrong interface!vendor: " << desc.vendor; + MS_LOG(ERROR) << "Call wrong interface!provider: " << desc.provider; return nullptr; } kernel::CreateKernel KernelRegistry::GetDelegateCreator(const kernel::KernelKey &desc) { - auto vendor_hash = std::hash{}(desc.vendor); - auto it_by_vendor = kernel_creators_.find(vendor_hash); - if (it_by_vendor == kernel_creators_.end()) { - return nullptr; - } - auto arch_hash = std::hash{}(desc.kernel_arch); - auto it_by_arch = it_by_vendor->second.find(arch_hash); - if (it_by_arch == it_by_vendor->second.end()) { - return nullptr; - } + kernel::CreateKernel creator = nullptr; auto index = GetFuncIndex(desc); - if (index < 0 || index >= kKernelMaxNum) { - MS_LOG(ERROR) << "invalid kernel key, arch " << desc.kernel_arch << ", data_type" << desc.data_type << ",op type " - << desc.type << ", vendor: " << desc.vendor; + if (index >= kKernelMaxNum || index < 0) { return nullptr; } - - return it_by_arch->second[index]; + std::unique_lock lock(lock_); + for (auto &&item : kernel_creators_) { + for (auto &&arch_item : item.second) { + creator = arch_item.second[index]; + if (creator != nullptr) { + break; + } + } + if (creator != nullptr) { + break; + } + } + return creator; } int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { @@ -172,6 +203,19 @@ KernelRegistry::~KernelRegistry() { free(instance->creator_arrays_); instance->creator_arrays_ = nullptr; } + + for (auto &&item : kernel_creators_) { + for (auto &&creator : item.second) { + free(creator.second); + creator.second = nullptr; + } + } + for (auto &&item : custom_kernel_creators_) { + for (auto &&creator : item.second) { + free(creator.second); + creator.second = nullptr; + } + } } bool KernelRegistry::SupportKernel(const KernelKey &key) { @@ -184,7 +228,7 @@ int KernelRegistry::GetKernel(const std::vector &in_tensors, const std kernel::LiteKernel **kernel, const void *primitive) { MS_ASSERT(ctx != nullptr); MS_ASSERT(kernel != nullptr); - if (key.vendor == kBuiltin) { + if (key.provider == kBuiltin) { auto creator = GetCreator(key); if (creator != nullptr) { *kernel = creator(in_tensors, out_tensors, parameter, ctx, key); @@ -196,17 +240,18 @@ int KernelRegistry::GetKernel(const std::vector &in_tensors, const std } } else { auto creator = GetDelegateCreator(key); - if (creator != nullptr) { - std::vector tensors_in; - Tensor2MSTensor(std::move(in_tensors), &tensors_in); - std::vector tensors_out; - Tensor2MSTensor(std::move(out_tensors), &tensors_out); - *kernel = creator(tensors_in, tensors_out, static_cast(primitive), ctx); - if (*kernel != nullptr) { - return RET_OK; - } - return RET_ERROR; + if (creator == nullptr) { + return RET_NOT_SUPPORT; } + std::vector tensors_in; + Tensor2MSTensor(std::move(in_tensors), &tensors_in); + std::vector tensors_out; + Tensor2MSTensor(std::move(out_tensors), &tensors_out); + *kernel = creator(tensors_in, tensors_out, static_cast(primitive), ctx); + if (*kernel != nullptr) { + return RET_OK; + } + return RET_ERROR; } return RET_NOT_SUPPORT; } diff --git a/mindspore/lite/src/kernel_registry.h b/mindspore/lite/src/kernel_registry.h index 1d4064bcc0..c553a99e2a 100644 --- a/mindspore/lite/src/kernel_registry.h +++ b/mindspore/lite/src/kernel_registry.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ #include +#include #include #include #include @@ -42,9 +43,14 @@ class KernelRegistry { virtual kernel::CreateKernel GetDelegateCreator(const kernel::KernelKey &desc); int GetCreatorFuncIndex(kernel::KernelKey desc); int GetFuncIndex(const kernel::KernelKey &desc); + const std::map> &kernel_creators() { + return kernel_creators_; + } + int RegCustomKernel(const std::string &arch, const std::string &vendor, TypeId data_type, const std::string &type, + kernel::CreateKernel creator); void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator); void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator); - int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type, + int RegKernel(const std::string &arch, const std::string &vendor, TypeId data_type, int type, kernel::CreateKernel creator); bool Merge(const std::unordered_map &newCreators); bool SupportKernel(const kernel::KernelKey &key); @@ -58,8 +64,8 @@ class KernelRegistry { static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1}; static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_}; kernel::KernelCreator *creator_arrays_ = nullptr; - std::unordered_map> kernel_creators_; - std::set all_vendors_; + std::map> kernel_creators_; + std::map> custom_kernel_creators_; private: std::mutex lock_; diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 92ade42615..5f93bcb349 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -41,11 +41,11 @@ struct KernelKey { TypeId data_type; int type; std::string kernel_arch; - std::string vendor{kBuiltin}; + std::string provider{kBuiltin}; bool operator<(const KernelKey &dst) const { - if (vendor != dst.vendor) { - return vendor < dst.vendor; + if (provider != dst.provider) { + return provider < dst.provider; } else if (kernel_arch != dst.kernel_arch) { return kernel_arch < dst.kernel_arch; } else if (arch != dst.arch) { diff --git a/mindspore/lite/src/register_kernel.cc b/mindspore/lite/src/register_kernel.cc index 5f4a07229a..97f331ae48 100644 --- a/mindspore/lite/src/register_kernel.cc +++ b/mindspore/lite/src/register_kernel.cc @@ -23,9 +23,14 @@ RegisterKernel *RegisterKernel::GetInstance() { return &instance; } -int RegisterKernel::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, - const int op_type, CreateKernel creator) { - return lite::KernelRegistry::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator); +int RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, + const std::string &type, CreateKernel creator) { + return lite::KernelRegistry::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator); +} + +int RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int op_type, + CreateKernel creator) { + return lite::KernelRegistry::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/lite/src/register_kernel.h b/mindspore/lite/src/register_kernel.h index 8f2843c088..b40d3b8bde 100644 --- a/mindspore/lite/src/register_kernel.h +++ b/mindspore/lite/src/register_kernel.h @@ -29,22 +29,30 @@ typedef kernel::LiteKernel *(*CreateKernel)(const std::vectorRegKernel(arch, provider, data_type, op_type, creator); + } + + KernelReg(const std::string &arch, const std::string &provider, TypeId data_type, const std::string &op_type, CreateKernel creator) { - RegisterKernel::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator); + RegisterKernel::GetInstance()->RegCustomKernel(arch, provider, data_type, op_type, creator); } }; -#define REGISTER_KERNEL(arch, vendor, data_type, op_type, creator) \ - static KernelReg g_##arch##vendor##data_type##op_type##kernelReg(arch, vendor, data_type, op_type, creator); +#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \ + static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator); + +#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \ + static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator); } // namespace kernel } // namespace mindspore diff --git a/mindspore/lite/src/runtime/infer_manager.cc b/mindspore/lite/src/runtime/infer_manager.cc index 12a7783872..0b916da758 100644 --- a/mindspore/lite/src/runtime/infer_manager.cc +++ b/mindspore/lite/src/runtime/infer_manager.cc @@ -14,14 +14,43 @@ * limitations under the License. */ #include "src/runtime/infer_manager.h" +#include +#include "src/common/prim_util.h" #include "src/common/tensor_util.h" #include "schema/model_generated.h" #include "include/errorcode.h" #include "nnacl/errorcode.h" #include "src/tensorlist.h" +#include "src/kernel_interface_registry.h" +#include "src/kernel_registry.h" namespace mindspore { namespace lite { +int KernelInferShape(const std::vector &inputs, const std::vector &outputs, + const void *primitive) { + std::vector in_tensors; + std::copy(inputs.begin(), inputs.end(), std::back_inserter(in_tensors)); + std::vector out_tensors; + std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors)); + int op_type = GetPrimitiveType(primitive); + for (auto &&item : KernelInterfaceRegistry::Instance()->kernel_interfaces()) { + auto provider = item.first; + auto kernel_interface = KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, op_type); + if (kernel_interface == nullptr) { + continue; + } + auto ret = kernel_interface->Infer(in_tensors, out_tensors, static_cast(primitive)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "provider: " << provider << ", op_type: " << PrimitiveTypeName(GetPrimitiveType(primitive)) + << " infer fail!"; + return ret; + } + return RET_OK; + } + + return RET_ERROR; +} + int KernelInferShape(const std::vector &inputs, std::vector *outputs, OpParameter *parameter) { std::vector in_tensors; diff --git a/mindspore/lite/src/runtime/infer_manager.h b/mindspore/lite/src/runtime/infer_manager.h index ff5ce5426a..c96c282a7e 100644 --- a/mindspore/lite/src/runtime/infer_manager.h +++ b/mindspore/lite/src/runtime/infer_manager.h @@ -25,8 +25,10 @@ #include "nnacl/infer/infer.h" namespace mindspore::lite { -int KernelInferShape(const std::vector &tensors_in, std::vector *outputs, +int KernelInferShape(const std::vector &inputs, std::vector *outputs, OpParameter *parameter); +int KernelInferShape(const std::vector &inputs, const std::vector &outputs, + const void *primitive); class InferManager { public: static InferManager *GetInstance() { diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 7f7b624d84..8977858d88 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -48,6 +48,7 @@ #if defined(ENABLE_ARM) && defined(ENABLE_FP16) #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" #endif +#include "src/kernel_interface_registry.h" namespace mindspore::lite { using kernel::KERNEL_ARCH::kCPU; @@ -130,6 +131,10 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) { std::vector inputs; std::vector outputs; FindNodeInoutTensors(*node, &inputs, &outputs); + if (KernelInterfaceRegistry::Instance()->CheckReg(node)) { + return KernelInferShape(inputs, outputs, node->primitive_); + } + int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(GetPrimitiveType(node->primitive_), schema_version); @@ -432,6 +437,18 @@ int Scheduler::FindNpuKernel(const std::vector &in_tensors, const std: return RET_NOT_SUPPORT; } +int Scheduler::FindProviderKernel(const std::vector &in_tensors, const std::vector &out_tensors, + const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel) { + int ret = RET_ERROR; + if (KernelRegistry::GetInstance()->kernel_creators().size() != 0 && + VersionManager::GetInstance()->GetSchemaVersion() != SCHEMA_V0) { + kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), "", ""}; + ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel, + node->primitive_); + } + return ret; +} + kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in_tensors, const std::vector &out_tensors, const Model::Node *node, TypeId prefer_data_type) { @@ -439,14 +456,18 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in // why we need this TypeId data_type = (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors); + kernel::LiteKernel *kernel = nullptr; + int status; + status = FindProviderKernel(in_tensors, out_tensors, node, data_type, &kernel); + if (status == RET_OK && kernel != nullptr) { + return kernel; + } OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)]; if (op_parameter == nullptr) { MS_LOG(ERROR) << "Can not find OpParameter!type: " << PrimitiveTypeName(GetPrimitiveType(node->primitive_)); return nullptr; } kernel::KernelKey desc{kCPU, data_type, static_cast(op_parameter->type_)}; - kernel::LiteKernel *kernel = nullptr; - int status; #ifdef SUPPORT_GPU // if (node->device_type_ == DT_GPU || node->device_type_ == DEFAULT) { status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel); diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h index e2e927128f..835194cb33 100644 --- a/mindspore/lite/src/scheduler.h +++ b/mindspore/lite/src/scheduler.h @@ -60,6 +60,7 @@ class Scheduler { kernel::LiteKernel *FindBackendKernel(const std::vector &in_tensors, const std::vector &out_tensors, const Model::Node *node, TypeId prefer_data_type = kTypeUnknown); + int FindCpuKernel(const std::vector &in_tensors, const std::vector &out_tensors, OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type, kernel::LiteKernel **kernel); @@ -67,6 +68,9 @@ class Scheduler { OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel); int FindNpuKernel(const std::vector &in_tensors, const std::vector &out_tensors, OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel); + + int FindProviderKernel(const std::vector &in_tensors, const std::vector &out_tensors, + const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel); // schedule a partial node to a subgraph_kernel kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node); // schedule a node to a kernel diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 5649cd6b97..3485ae1791 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -141,6 +141,8 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/tensorlist.cc ${LITE_DIR}/src/executor.cc ${LITE_DIR}/src/inner_context.cc + ${LITE_DIR}/src/kernel_interface.cc + ${LITE_DIR}/src/kernel_interface_registry.cc ${LITE_DIR}/src/kernel_registry.cc ${LITE_DIR}/src/register_kernel.cc ${LITE_DIR}/src/lite_kernel.cc diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 00652fffed..de432bfa92 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -119,6 +119,7 @@ set(LITE_SRC ${SRC_DIR}/tensor.cc ${SRC_DIR}/ms_tensor.cc ${SRC_DIR}/tensorlist.cc + ${SRC_DIR}/kernel_interface_registry.cc ${SRC_DIR}/kernel_registry.cc ${SRC_DIR}/register_kernel.cc ${SRC_DIR}/lite_kernel.cc