From: @jpc_chenjianping Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tongpull/15407/MERGE
| @@ -60,6 +60,9 @@ set(LITE_SRC | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/register_kernel.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface_registry.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "src/common/tensor_util.h" | |||
| #include <algorithm> | |||
| #include "schema/model_generated.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| @@ -226,5 +227,8 @@ int CheckTensorsInvalid(const std::vector<Tensor *> &tensors) { | |||
| return RET_OK; | |||
| } | |||
| void Tensor2MSTensor(const std::vector<Tensor *> &&tensors, std::vector<tensor::MSTensor *> *out_tensors) { | |||
| std::copy(tensors.begin(), tensors.end(), std::back_inserter(*out_tensors)); | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -40,6 +40,7 @@ int GenerateOutTensorC(const OpParameter *const parameter, const std::vector<lit | |||
| std::vector<lite::Tensor *> *outputs, std::vector<TensorC *> *out_tensor_c); | |||
| int CheckTensorsInvalid(const std::vector<Tensor *> &tensors); | |||
| void Tensor2MSTensor(const std::vector<Tensor *> &&tensors, std::vector<tensor::MSTensor *> *out_tensors); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/kernel_interface.h" | |||
| #include "src/kernel_interface_registry.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| RegisterKernelInterface *RegisterKernelInterface::Instance() { | |||
| static RegisterKernelInterface instance; | |||
| return &instance; | |||
| } | |||
| int RegisterKernelInterface::Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) { | |||
| return lite::KernelInterfaceRegistry::Instance()->Reg(vendor, op_type, creator); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_ | |||
| #define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "include/ms_tensor.h" | |||
| #include "schema/model_generated.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| struct CapabilityParam { | |||
| float exec_time_; | |||
| float power_usage_; | |||
| }; | |||
| class KernelInterface { | |||
| public: | |||
| virtual ~KernelInterface() = default; | |||
| virtual int Infer(const std::vector<tensor::MSTensor *> &tensor_in, std::vector<tensor::MSTensor *> *outputs, | |||
| const schema::Primitive *primitive) { | |||
| return 0; | |||
| } | |||
| virtual int GetCapability(const std::vector<tensor::MSTensor *> &tensor_in, const schema::Primitive *primitive, | |||
| CapabilityParam *param) { | |||
| return 0; | |||
| } | |||
| }; | |||
| typedef KernelInterface *(*KernelInterfaceCreator)(); | |||
| class RegisterKernelInterface { | |||
| public: | |||
| static RegisterKernelInterface *Instance(); | |||
| int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator); | |||
| private: | |||
| RegisterKernelInterface() = default; | |||
| }; | |||
| class KernelInterfaceReg { | |||
| public: | |||
| KernelInterfaceReg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) { | |||
| RegisterKernelInterface::Instance()->Reg(vendor, op_type, creator); | |||
| } | |||
| ~KernelInterfaceReg() = default; | |||
| }; | |||
| #define REGISTER_KERNEL_INTERFACE(vendor, op_type, creator) \ | |||
| static KernelInterfaceReg g_##vendor##op_type##_inter_reg(vendor, op_type, creator); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_ | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/kernel_interface_registry.h" | |||
| #include "src/kernel_interface.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| using mindspore::kernel::KernelInterfaceCreator; | |||
| using mindspore::schema::PrimitiveType_MAX; | |||
| using mindspore::schema::PrimitiveType_MIN; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace { | |||
| static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1; | |||
| } | |||
| int KernelInterfaceRegistry::Reg(const std::string &vendor, const int &op_type, KernelInterfaceCreator creator) { | |||
| auto vendor_hash = std::hash<std::string>{}(vendor); | |||
| auto iter = kernel_interfaces_.find(vendor_hash); | |||
| if (iter == kernel_interfaces_.end()) { | |||
| kernel_interfaces_[vendor_hash] = | |||
| reinterpret_cast<KernelInterfaceCreator *>(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator))); | |||
| if (kernel_interfaces_[vendor_hash] == nullptr) { | |||
| MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) { | |||
| MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum; | |||
| return RET_ERROR; | |||
| } | |||
| kernel_interfaces_[vendor_hash][op_type] = creator; | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_ | |||
| #define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "src/kernel_interface.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class KernelInterfaceRegistry { | |||
| public: | |||
| static KernelInterfaceRegistry *Instance() { | |||
| static KernelInterfaceRegistry instance; | |||
| return &instance; | |||
| } | |||
| int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator); | |||
| private: | |||
| KernelInterfaceRegistry() = default; | |||
| std::unordered_map<size_t, kernel::KernelInterfaceCreator *> kernel_interfaces_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_ | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/kernel_registry.h" | |||
| #include <utility> | |||
| #include "include/errorcode.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| #include "src/common/version_manager.h" | |||
| @@ -26,13 +27,20 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| #endif | |||
| #include "src/common/tensor_util.h" | |||
| using mindspore::kernel::CreateKernel; | |||
| using mindspore::kernel::kBuiltin; | |||
| using mindspore::kernel::kCPU; | |||
| using mindspore::kernel::KERNEL_ARCH; | |||
| using mindspore::kernel::KernelCreator; | |||
| using mindspore::kernel::KernelKey; | |||
| namespace mindspore::lite { | |||
| namespace { | |||
| static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin + 1) * (PrimitiveType_MAX - PrimitiveType_MIN + 1); | |||
| } // namespace | |||
| KernelRegistry *KernelRegistry::GetInstance() { | |||
| static KernelRegistry instance; | |||
| @@ -47,6 +55,47 @@ KernelRegistry *KernelRegistry::GetInstance() { | |||
| return &instance; | |||
| } | |||
| int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) { | |||
| int dType_index = static_cast<int>(desc.data_type) - kNumberTypeBegin; | |||
| return dType_index * op_type_length_ + desc.type; | |||
| } | |||
| int KernelRegistry::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, | |||
| const int type, kernel::CreateKernel creator) { | |||
| auto vendor_hash = std::hash<std::string>{}(vendor); | |||
| auto arch_hash = std::hash<std::string>{}(arch); | |||
| auto iter = kernel_creators_.find(vendor_hash); | |||
| if (iter == kernel_creators_.end()) { | |||
| all_vendors_.insert(vendor); | |||
| kernel_creators_[vendor_hash][arch_hash] = | |||
| reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel))); | |||
| if (kernel_creators_[vendor_hash][arch_hash] == nullptr) { | |||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch; | |||
| return RET_ERROR; | |||
| } | |||
| memset(kernel_creators_[vendor_hash][arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel)); | |||
| } else { | |||
| auto iter_arch = iter->second.find(arch_hash); | |||
| if (iter_arch == iter->second.end()) { | |||
| iter->second[arch_hash] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel))); | |||
| if (iter->second[arch_hash] == nullptr) { | |||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch; | |||
| return RET_ERROR; | |||
| } | |||
| memset(iter->second[arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel)); | |||
| } | |||
| } | |||
| KernelKey desc = {kCPU, data_type, type, arch, vendor}; | |||
| int index = GetFuncIndex(desc); | |||
| if (index >= kKernelMaxNum) { | |||
| MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type; | |||
| return RET_ERROR; | |||
| } | |||
| kernel_creators_[vendor_hash][arch_hash][index] = creator; | |||
| return RET_OK; | |||
| } | |||
| int KernelRegistry::Init() { | |||
| #ifdef ENABLE_ARM64 | |||
| if (mindspore::lite::IsSupportSDot()) { | |||
| @@ -66,17 +115,38 @@ int KernelRegistry::Init() { | |||
| } | |||
| kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { | |||
| int index = GetCreatorFuncIndex(desc); | |||
| if (index >= array_size_ || index < 0) { | |||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " | |||
| << desc.type; | |||
| if (desc.vendor == kBuiltin) { | |||
| int index = GetCreatorFuncIndex(desc); | |||
| if (index >= array_size_ || index < 0) { | |||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " | |||
| << desc.type; | |||
| return nullptr; | |||
| } | |||
| return creator_arrays_[index]; | |||
| } | |||
| MS_LOG(ERROR) << "Call wrong interface!vendor: " << desc.vendor; | |||
| return nullptr; | |||
| } | |||
| kernel::CreateKernel KernelRegistry::GetDelegateCreator(const kernel::KernelKey &desc) { | |||
| auto vendor_hash = std::hash<std::string>{}(desc.vendor); | |||
| auto it_by_vendor = kernel_creators_.find(vendor_hash); | |||
| if (it_by_vendor == kernel_creators_.end()) { | |||
| return nullptr; | |||
| } | |||
| auto it = creator_arrays_[index]; | |||
| if (it != nullptr) { | |||
| return it; | |||
| auto arch_hash = std::hash<std::string>{}(desc.kernel_arch); | |||
| auto it_by_arch = it_by_vendor->second.find(arch_hash); | |||
| if (it_by_arch == it_by_vendor->second.end()) { | |||
| return nullptr; | |||
| } | |||
| return nullptr; | |||
| auto index = GetFuncIndex(desc); | |||
| if (index < 0 || index >= kKernelMaxNum) { | |||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.kernel_arch << ", data_type" << desc.data_type << ",op type " | |||
| << desc.type << ", vendor: " << desc.vendor; | |||
| return nullptr; | |||
| } | |||
| return it_by_arch->second[index]; | |||
| } | |||
| int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { | |||
| @@ -127,15 +197,28 @@ bool KernelRegistry::SupportKernel(const KernelKey &key) { | |||
| kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, | |||
| const std::vector<Tensor *> &out_tensors, const InnerContext *ctx, | |||
| const kernel::KernelKey &key, OpParameter *parameter) { | |||
| const kernel::KernelKey &key, OpParameter *parameter, | |||
| const void *primitive) { | |||
| MS_ASSERT(ctx != nullptr); | |||
| auto creator = GetCreator(key); | |||
| if (creator != nullptr) { | |||
| auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key); | |||
| if (kernel != nullptr) { | |||
| kernel->set_desc(key); | |||
| return kernel; | |||
| if (key.vendor == kBuiltin) { | |||
| auto creator = GetCreator(key); | |||
| if (creator != nullptr) { | |||
| auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key); | |||
| if (kernel != nullptr) { | |||
| kernel->set_desc(key); | |||
| return kernel; | |||
| } | |||
| } | |||
| } else { | |||
| auto creator = GetDelegateCreator(key); | |||
| if (creator == nullptr) { | |||
| return nullptr; | |||
| } | |||
| std::vector<tensor::MSTensor *> tensors_in; | |||
| Tensor2MSTensor(std::move(in_tensors), &tensors_in); | |||
| std::vector<tensor::MSTensor *> tensors_out; | |||
| Tensor2MSTensor(std::move(out_tensors), &tensors_out); | |||
| return creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx); | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -20,7 +20,9 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <set> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/register_kernel.h" | |||
| #include "schema/model_generated.h" | |||
| using mindspore::kernel::kKernelArch_MAX; | |||
| @@ -37,16 +39,18 @@ class KernelRegistry { | |||
| static KernelRegistry *GetInstance(); | |||
| static int Init(); | |||
| virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc); | |||
| virtual kernel::CreateKernel GetDelegateCreator(const kernel::KernelKey &desc); | |||
| int GetCreatorFuncIndex(kernel::KernelKey desc); | |||
| int GetFuncIndex(const kernel::KernelKey &desc); | |||
| void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator); | |||
| void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator); | |||
| int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type, | |||
| kernel::CreateKernel creator); | |||
| bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators); | |||
| int GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | |||
| const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter, | |||
| kernel::LiteKernel **kernel); | |||
| bool SupportKernel(const kernel::KernelKey &key); | |||
| kernel::LiteKernel *GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | |||
| const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter); | |||
| const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter, | |||
| const void *primitive = nullptr); | |||
| protected: | |||
| static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1}; | |||
| @@ -54,6 +58,8 @@ class KernelRegistry { | |||
| static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1}; | |||
| static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_}; | |||
| kernel::KernelCreator *creator_arrays_ = nullptr; | |||
| std::unordered_map<std::size_t, std::unordered_map<std::size_t, kernel::CreateKernel *>> kernel_creators_; | |||
| std::set<std::string> all_vendors_; | |||
| private: | |||
| std::mutex lock_; | |||
| @@ -34,14 +34,21 @@ | |||
| namespace mindspore::kernel { | |||
| enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU }; | |||
| static const char *const kBuiltin = "Builtin"; | |||
| struct KernelKey { | |||
| KERNEL_ARCH arch; | |||
| TypeId data_type; | |||
| int type; | |||
| std::string kernel_arch; | |||
| std::string vendor{kBuiltin}; | |||
| bool operator<(const KernelKey &dst) const { | |||
| if (arch != dst.arch) { | |||
| if (vendor != dst.vendor) { | |||
| return vendor < dst.vendor; | |||
| } else if (kernel_arch != dst.kernel_arch) { | |||
| return kernel_arch < dst.kernel_arch; | |||
| } else if (arch != dst.arch) { | |||
| return arch < dst.arch; | |||
| } else if (data_type != dst.data_type) { | |||
| return data_type < dst.data_type; | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/register_kernel.h" | |||
| #include "src/kernel_registry.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| RegisterKernel *RegisterKernel::GetInstance() { | |||
| static RegisterKernel instance; | |||
| return &instance; | |||
| } | |||
| int RegisterKernel::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, | |||
| const int op_type, CreateKernel creator) { | |||
| return lite::KernelRegistry::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_ | |||
| #define MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| typedef kernel::LiteKernel *(*CreateKernel)(const std::vector<tensor::MSTensor *> &inputs, | |||
| const std::vector<tensor::MSTensor *> &outputs, | |||
| const schema::Primitive *primitive, const lite::Context *ctx); | |||
| class RegisterKernel { | |||
| public: | |||
| static RegisterKernel *GetInstance(); | |||
| int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type, | |||
| CreateKernel creator); | |||
| }; | |||
| class KernelReg { | |||
| public: | |||
| ~KernelReg() = default; | |||
| KernelReg(const std::string &arch, const std::string &vendor, const TypeId data_type, const int op_type, | |||
| CreateKernel creator) { | |||
| RegisterKernel::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator); | |||
| } | |||
| }; | |||
| #define REGISTER_KERNEL(arch, vendor, data_type, op_type, creator) \ | |||
| static KernelReg g_##arch##vendor##data_type##op_type##kernelReg(arch, vendor, data_type, op_type, creator); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_ | |||
| @@ -142,6 +142,7 @@ set(TEST_LITE_SRC | |||
| ${LITE_DIR}/src/executor.cc | |||
| ${LITE_DIR}/src/inner_context.cc | |||
| ${LITE_DIR}/src/kernel_registry.cc | |||
| ${LITE_DIR}/src/register_kernel.cc | |||
| ${LITE_DIR}/src/lite_kernel.cc | |||
| ${LITE_DIR}/src/lite_kernel_util.cc | |||
| ${LITE_DIR}/src/lite_session.cc | |||
| @@ -111,6 +111,7 @@ set(LITE_SRC | |||
| ${SRC_DIR}/ms_tensor.cc | |||
| ${SRC_DIR}/tensorlist.cc | |||
| ${SRC_DIR}/kernel_registry.cc | |||
| ${SRC_DIR}/register_kernel.cc | |||
| ${SRC_DIR}/lite_kernel.cc | |||
| ${SRC_DIR}/lite_kernel_util.cc | |||
| ${SRC_DIR}/scheduler.cc | |||