From: @jpc_chenjianping Reviewed-by: @zhang_xue_tong, @zhanghaibo5 Signed-off-by: @zhang_xue_tong (pull/15407/MERGE)
| @@ -60,6 +60,9 @@ set(LITE_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc | ${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc | ${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc | ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/register_kernel.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface_registry.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc | ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel_util.cc | ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel_util.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc | ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "src/common/tensor_util.h" | #include "src/common/tensor_util.h" | ||||
| #include <algorithm> | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| @@ -226,5 +227,8 @@ int CheckTensorsInvalid(const std::vector<Tensor *> &tensors) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void Tensor2MSTensor(const std::vector<Tensor *> &&tensors, std::vector<tensor::MSTensor *> *out_tensors) { | |||||
| std::copy(tensors.begin(), tensors.end(), std::back_inserter(*out_tensors)); | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -40,6 +40,7 @@ int GenerateOutTensorC(const OpParameter *const parameter, const std::vector<lit | |||||
| std::vector<lite::Tensor *> *outputs, std::vector<TensorC *> *out_tensor_c); | std::vector<lite::Tensor *> *outputs, std::vector<TensorC *> *out_tensor_c); | ||||
| int CheckTensorsInvalid(const std::vector<Tensor *> &tensors); | int CheckTensorsInvalid(const std::vector<Tensor *> &tensors); | ||||
| void Tensor2MSTensor(const std::vector<Tensor *> &&tensors, std::vector<tensor::MSTensor *> *out_tensors); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,30 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/kernel_interface.h" | |||||
| #include "src/kernel_interface_registry.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| RegisterKernelInterface *RegisterKernelInterface::Instance() { | |||||
| static RegisterKernelInterface instance; | |||||
| return &instance; | |||||
| } | |||||
| int RegisterKernelInterface::Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) { | |||||
| return lite::KernelInterfaceRegistry::Instance()->Reg(vendor, op_type, creator); | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_ | |||||
| #define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "include/ms_tensor.h" | |||||
| #include "schema/model_generated.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
// Output record filled in by KernelInterface::GetCapability.
// Both fields default to zero so that an implementation which leaves the record untouched
// (such as the default GetCapability) never hands indeterminate floats back to the caller.
// NOTE(review): units are not visible here — presumably execution time and power cost estimates; confirm.
struct CapabilityParam {
  float exec_time_{0.0f};
  float power_usage_{0.0f};
};
| class KernelInterface { | |||||
| public: | |||||
| virtual ~KernelInterface() = default; | |||||
| virtual int Infer(const std::vector<tensor::MSTensor *> &tensor_in, std::vector<tensor::MSTensor *> *outputs, | |||||
| const schema::Primitive *primitive) { | |||||
| return 0; | |||||
| } | |||||
| virtual int GetCapability(const std::vector<tensor::MSTensor *> &tensor_in, const schema::Primitive *primitive, | |||||
| CapabilityParam *param) { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| typedef KernelInterface *(*KernelInterfaceCreator)(); | |||||
| class RegisterKernelInterface { | |||||
| public: | |||||
| static RegisterKernelInterface *Instance(); | |||||
| int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator); | |||||
| private: | |||||
| RegisterKernelInterface() = default; | |||||
| }; | |||||
| class KernelInterfaceReg { | |||||
| public: | |||||
| KernelInterfaceReg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) { | |||||
| RegisterKernelInterface::Instance()->Reg(vendor, op_type, creator); | |||||
| } | |||||
| ~KernelInterfaceReg() = default; | |||||
| }; | |||||
| #define REGISTER_KERNEL_INTERFACE(vendor, op_type, creator) \ | |||||
| static KernelInterfaceReg g_##vendor##op_type##_inter_reg(vendor, op_type, creator); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_ | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/kernel_interface_registry.h" | |||||
| #include "src/kernel_interface.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| using mindspore::kernel::KernelInterfaceCreator; | |||||
| using mindspore::schema::PrimitiveType_MAX; | |||||
| using mindspore::schema::PrimitiveType_MIN; | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace { | |||||
| static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1; | |||||
| } | |||||
| int KernelInterfaceRegistry::Reg(const std::string &vendor, const int &op_type, KernelInterfaceCreator creator) { | |||||
| auto vendor_hash = std::hash<std::string>{}(vendor); | |||||
| auto iter = kernel_interfaces_.find(vendor_hash); | |||||
| if (iter == kernel_interfaces_.end()) { | |||||
| kernel_interfaces_[vendor_hash] = | |||||
| reinterpret_cast<KernelInterfaceCreator *>(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator))); | |||||
| if (kernel_interfaces_[vendor_hash] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) { | |||||
| MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum; | |||||
| return RET_ERROR; | |||||
| } | |||||
| kernel_interfaces_[vendor_hash][op_type] = creator; | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_ | |||||
| #define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_ | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include "src/kernel_interface.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class KernelInterfaceRegistry { | |||||
| public: | |||||
| static KernelInterfaceRegistry *Instance() { | |||||
| static KernelInterfaceRegistry instance; | |||||
| return &instance; | |||||
| } | |||||
| int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator); | |||||
| private: | |||||
| KernelInterfaceRegistry() = default; | |||||
| std::unordered_map<size_t, kernel::KernelInterfaceCreator *> kernel_interfaces_; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_ | |||||
| @@ -14,6 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include <utility> | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "src/common/version_manager.h" | #include "src/common/version_manager.h" | ||||
| @@ -26,13 +27,20 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #endif | #endif | ||||
| #include "src/common/tensor_util.h" | |||||
| using mindspore::kernel::CreateKernel; | |||||
| using mindspore::kernel::kBuiltin; | |||||
| using mindspore::kernel::kCPU; | using mindspore::kernel::kCPU; | ||||
| using mindspore::kernel::KERNEL_ARCH; | using mindspore::kernel::KERNEL_ARCH; | ||||
| using mindspore::kernel::KernelCreator; | using mindspore::kernel::KernelCreator; | ||||
| using mindspore::kernel::KernelKey; | using mindspore::kernel::KernelKey; | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| namespace { | |||||
| static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin + 1) * (PrimitiveType_MAX - PrimitiveType_MIN + 1); | |||||
| } // namespace | |||||
| KernelRegistry *KernelRegistry::GetInstance() { | KernelRegistry *KernelRegistry::GetInstance() { | ||||
| static KernelRegistry instance; | static KernelRegistry instance; | ||||
| @@ -47,6 +55,47 @@ KernelRegistry *KernelRegistry::GetInstance() { | |||||
| return &instance; | return &instance; | ||||
| } | } | ||||
| int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) { | |||||
| int dType_index = static_cast<int>(desc.data_type) - kNumberTypeBegin; | |||||
| return dType_index * op_type_length_ + desc.type; | |||||
| } | |||||
| int KernelRegistry::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, | |||||
| const int type, kernel::CreateKernel creator) { | |||||
| auto vendor_hash = std::hash<std::string>{}(vendor); | |||||
| auto arch_hash = std::hash<std::string>{}(arch); | |||||
| auto iter = kernel_creators_.find(vendor_hash); | |||||
| if (iter == kernel_creators_.end()) { | |||||
| all_vendors_.insert(vendor); | |||||
| kernel_creators_[vendor_hash][arch_hash] = | |||||
| reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel))); | |||||
| if (kernel_creators_[vendor_hash][arch_hash] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(kernel_creators_[vendor_hash][arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel)); | |||||
| } else { | |||||
| auto iter_arch = iter->second.find(arch_hash); | |||||
| if (iter_arch == iter->second.end()) { | |||||
| iter->second[arch_hash] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel))); | |||||
| if (iter->second[arch_hash] == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(iter->second[arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel)); | |||||
| } | |||||
| } | |||||
| KernelKey desc = {kCPU, data_type, type, arch, vendor}; | |||||
| int index = GetFuncIndex(desc); | |||||
| if (index >= kKernelMaxNum) { | |||||
| MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| kernel_creators_[vendor_hash][arch_hash][index] = creator; | |||||
| return RET_OK; | |||||
| } | |||||
| int KernelRegistry::Init() { | int KernelRegistry::Init() { | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| if (mindspore::lite::IsSupportSDot()) { | if (mindspore::lite::IsSupportSDot()) { | ||||
| @@ -66,17 +115,38 @@ int KernelRegistry::Init() { | |||||
| } | } | ||||
| kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { | kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { | ||||
| int index = GetCreatorFuncIndex(desc); | |||||
| if (index >= array_size_ || index < 0) { | |||||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " | |||||
| << desc.type; | |||||
| if (desc.vendor == kBuiltin) { | |||||
| int index = GetCreatorFuncIndex(desc); | |||||
| if (index >= array_size_ || index < 0) { | |||||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " | |||||
| << desc.type; | |||||
| return nullptr; | |||||
| } | |||||
| return creator_arrays_[index]; | |||||
| } | |||||
| MS_LOG(ERROR) << "Call wrong interface!vendor: " << desc.vendor; | |||||
| return nullptr; | |||||
| } | |||||
| kernel::CreateKernel KernelRegistry::GetDelegateCreator(const kernel::KernelKey &desc) { | |||||
| auto vendor_hash = std::hash<std::string>{}(desc.vendor); | |||||
| auto it_by_vendor = kernel_creators_.find(vendor_hash); | |||||
| if (it_by_vendor == kernel_creators_.end()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto it = creator_arrays_[index]; | |||||
| if (it != nullptr) { | |||||
| return it; | |||||
| auto arch_hash = std::hash<std::string>{}(desc.kernel_arch); | |||||
| auto it_by_arch = it_by_vendor->second.find(arch_hash); | |||||
| if (it_by_arch == it_by_vendor->second.end()) { | |||||
| return nullptr; | |||||
| } | } | ||||
| return nullptr; | |||||
| auto index = GetFuncIndex(desc); | |||||
| if (index < 0 || index >= kKernelMaxNum) { | |||||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.kernel_arch << ", data_type" << desc.data_type << ",op type " | |||||
| << desc.type << ", vendor: " << desc.vendor; | |||||
| return nullptr; | |||||
| } | |||||
| return it_by_arch->second[index]; | |||||
| } | } | ||||
| int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { | int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { | ||||
| @@ -127,15 +197,28 @@ bool KernelRegistry::SupportKernel(const KernelKey &key) { | |||||
| kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, | kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, | ||||
| const std::vector<Tensor *> &out_tensors, const InnerContext *ctx, | const std::vector<Tensor *> &out_tensors, const InnerContext *ctx, | ||||
| const kernel::KernelKey &key, OpParameter *parameter) { | |||||
| const kernel::KernelKey &key, OpParameter *parameter, | |||||
| const void *primitive) { | |||||
| MS_ASSERT(ctx != nullptr); | MS_ASSERT(ctx != nullptr); | ||||
| auto creator = GetCreator(key); | |||||
| if (creator != nullptr) { | |||||
| auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key); | |||||
| if (kernel != nullptr) { | |||||
| kernel->set_desc(key); | |||||
| return kernel; | |||||
| if (key.vendor == kBuiltin) { | |||||
| auto creator = GetCreator(key); | |||||
| if (creator != nullptr) { | |||||
| auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key); | |||||
| if (kernel != nullptr) { | |||||
| kernel->set_desc(key); | |||||
| return kernel; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| auto creator = GetDelegateCreator(key); | |||||
| if (creator == nullptr) { | |||||
| return nullptr; | |||||
| } | } | ||||
| std::vector<tensor::MSTensor *> tensors_in; | |||||
| Tensor2MSTensor(std::move(in_tensors), &tensors_in); | |||||
| std::vector<tensor::MSTensor *> tensors_out; | |||||
| Tensor2MSTensor(std::move(out_tensors), &tensors_out); | |||||
| return creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx); | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -20,7 +20,9 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include <set> | |||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "src/register_kernel.h" | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| using mindspore::kernel::kKernelArch_MAX; | using mindspore::kernel::kKernelArch_MAX; | ||||
| @@ -37,16 +39,18 @@ class KernelRegistry { | |||||
| static KernelRegistry *GetInstance(); | static KernelRegistry *GetInstance(); | ||||
| static int Init(); | static int Init(); | ||||
| virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc); | virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc); | ||||
| virtual kernel::CreateKernel GetDelegateCreator(const kernel::KernelKey &desc); | |||||
| int GetCreatorFuncIndex(kernel::KernelKey desc); | int GetCreatorFuncIndex(kernel::KernelKey desc); | ||||
| int GetFuncIndex(const kernel::KernelKey &desc); | |||||
| void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator); | void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator); | ||||
| void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator); | void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator); | ||||
| int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type, | |||||
| kernel::CreateKernel creator); | |||||
| bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators); | bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators); | ||||
| int GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | |||||
| const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter, | |||||
| kernel::LiteKernel **kernel); | |||||
| bool SupportKernel(const kernel::KernelKey &key); | bool SupportKernel(const kernel::KernelKey &key); | ||||
| kernel::LiteKernel *GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | kernel::LiteKernel *GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | ||||
| const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter); | |||||
| const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter, | |||||
| const void *primitive = nullptr); | |||||
| protected: | protected: | ||||
| static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1}; | static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1}; | ||||
| @@ -54,6 +58,8 @@ class KernelRegistry { | |||||
| static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1}; | static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1}; | ||||
| static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_}; | static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_}; | ||||
| kernel::KernelCreator *creator_arrays_ = nullptr; | kernel::KernelCreator *creator_arrays_ = nullptr; | ||||
| std::unordered_map<std::size_t, std::unordered_map<std::size_t, kernel::CreateKernel *>> kernel_creators_; | |||||
| std::set<std::string> all_vendors_; | |||||
| private: | private: | ||||
| std::mutex lock_; | std::mutex lock_; | ||||
| @@ -34,14 +34,21 @@ | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU }; | enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU }; | ||||
| static const char *const kBuiltin = "Builtin"; | |||||
| struct KernelKey { | struct KernelKey { | ||||
| KERNEL_ARCH arch; | KERNEL_ARCH arch; | ||||
| TypeId data_type; | TypeId data_type; | ||||
| int type; | int type; | ||||
| std::string kernel_arch; | |||||
| std::string vendor{kBuiltin}; | |||||
| bool operator<(const KernelKey &dst) const { | bool operator<(const KernelKey &dst) const { | ||||
| if (arch != dst.arch) { | |||||
| if (vendor != dst.vendor) { | |||||
| return vendor < dst.vendor; | |||||
| } else if (kernel_arch != dst.kernel_arch) { | |||||
| return kernel_arch < dst.kernel_arch; | |||||
| } else if (arch != dst.arch) { | |||||
| return arch < dst.arch; | return arch < dst.arch; | ||||
| } else if (data_type != dst.data_type) { | } else if (data_type != dst.data_type) { | ||||
| return data_type < dst.data_type; | return data_type < dst.data_type; | ||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/register_kernel.h" | |||||
| #include "src/kernel_registry.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| RegisterKernel *RegisterKernel::GetInstance() { | |||||
| static RegisterKernel instance; | |||||
| return &instance; | |||||
| } | |||||
| int RegisterKernel::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, | |||||
| const int op_type, CreateKernel creator) { | |||||
| return lite::KernelRegistry::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator); | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_ | |||||
| #define MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| typedef kernel::LiteKernel *(*CreateKernel)(const std::vector<tensor::MSTensor *> &inputs, | |||||
| const std::vector<tensor::MSTensor *> &outputs, | |||||
| const schema::Primitive *primitive, const lite::Context *ctx); | |||||
| class RegisterKernel { | |||||
| public: | |||||
| static RegisterKernel *GetInstance(); | |||||
| int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type, | |||||
| CreateKernel creator); | |||||
| }; | |||||
| class KernelReg { | |||||
| public: | |||||
| ~KernelReg() = default; | |||||
| KernelReg(const std::string &arch, const std::string &vendor, const TypeId data_type, const int op_type, | |||||
| CreateKernel creator) { | |||||
| RegisterKernel::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator); | |||||
| } | |||||
| }; | |||||
| #define REGISTER_KERNEL(arch, vendor, data_type, op_type, creator) \ | |||||
| static KernelReg g_##arch##vendor##data_type##op_type##kernelReg(arch, vendor, data_type, op_type, creator); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_ | |||||
| @@ -142,6 +142,7 @@ set(TEST_LITE_SRC | |||||
| ${LITE_DIR}/src/executor.cc | ${LITE_DIR}/src/executor.cc | ||||
| ${LITE_DIR}/src/inner_context.cc | ${LITE_DIR}/src/inner_context.cc | ||||
| ${LITE_DIR}/src/kernel_registry.cc | ${LITE_DIR}/src/kernel_registry.cc | ||||
| ${LITE_DIR}/src/register_kernel.cc | |||||
| ${LITE_DIR}/src/lite_kernel.cc | ${LITE_DIR}/src/lite_kernel.cc | ||||
| ${LITE_DIR}/src/lite_kernel_util.cc | ${LITE_DIR}/src/lite_kernel_util.cc | ||||
| ${LITE_DIR}/src/lite_session.cc | ${LITE_DIR}/src/lite_session.cc | ||||
| @@ -111,6 +111,7 @@ set(LITE_SRC | |||||
| ${SRC_DIR}/ms_tensor.cc | ${SRC_DIR}/ms_tensor.cc | ||||
| ${SRC_DIR}/tensorlist.cc | ${SRC_DIR}/tensorlist.cc | ||||
| ${SRC_DIR}/kernel_registry.cc | ${SRC_DIR}/kernel_registry.cc | ||||
| ${SRC_DIR}/register_kernel.cc | |||||
| ${SRC_DIR}/lite_kernel.cc | ${SRC_DIR}/lite_kernel.cc | ||||
| ${SRC_DIR}/lite_kernel_util.cc | ${SRC_DIR}/lite_kernel_util.cc | ||||
| ${SRC_DIR}/scheduler.cc | ${SRC_DIR}/scheduler.cc | ||||