Browse Source

!15407 [MS][LITE][STABLE]support third vendor reg kernel

From: @jpc_chenjianping
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
pull/15407/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
40ecc3233f
14 changed files with 400 additions and 20 deletions
  1. +3
    -0
      mindspore/lite/src/CMakeLists.txt
  2. +4
    -0
      mindspore/lite/src/common/tensor_util.cc
  3. +1
    -0
      mindspore/lite/src/common/tensor_util.h
  4. +30
    -0
      mindspore/lite/src/kernel_interface.cc
  5. +69
    -0
      mindspore/lite/src/kernel_interface.h
  6. +50
    -0
      mindspore/lite/src/kernel_interface_registry.cc
  7. +43
    -0
      mindspore/lite/src/kernel_interface_registry.h
  8. +98
    -15
      mindspore/lite/src/kernel_registry.cc
  9. +10
    -4
      mindspore/lite/src/kernel_registry.h
  10. +8
    -1
      mindspore/lite/src/lite_kernel.h
  11. +31
    -0
      mindspore/lite/src/register_kernel.cc
  12. +51
    -0
      mindspore/lite/src/register_kernel.h
  13. +1
    -0
      mindspore/lite/test/CMakeLists.txt
  14. +1
    -0
      mindspore/lite/tools/converter/CMakeLists.txt

+ 3
- 0
mindspore/lite/src/CMakeLists.txt View File

@@ -60,6 +60,9 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc ${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/register_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc


+ 4
- 0
mindspore/lite/src/common/tensor_util.cc View File

@@ -15,6 +15,7 @@
*/ */


#include "src/common/tensor_util.h" #include "src/common/tensor_util.h"
#include <algorithm>
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
@@ -226,5 +227,8 @@ int CheckTensorsInvalid(const std::vector<Tensor *> &tensors) {
return RET_OK; return RET_OK;
} }


void Tensor2MSTensor(const std::vector<Tensor *> &&tensors, std::vector<tensor::MSTensor *> *out_tensors) {
std::copy(tensors.begin(), tensors.end(), std::back_inserter(*out_tensors));
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 1
- 0
mindspore/lite/src/common/tensor_util.h View File

@@ -40,6 +40,7 @@ int GenerateOutTensorC(const OpParameter *const parameter, const std::vector<lit
std::vector<lite::Tensor *> *outputs, std::vector<TensorC *> *out_tensor_c); std::vector<lite::Tensor *> *outputs, std::vector<TensorC *> *out_tensor_c);


int CheckTensorsInvalid(const std::vector<Tensor *> &tensors); int CheckTensorsInvalid(const std::vector<Tensor *> &tensors);
void Tensor2MSTensor(const std::vector<Tensor *> &&tensors, std::vector<tensor::MSTensor *> *out_tensors);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore




+ 30
- 0
mindspore/lite/src/kernel_interface.cc View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/kernel_interface.h"
#include "src/kernel_interface_registry.h"

namespace mindspore {
namespace kernel {
RegisterKernelInterface *RegisterKernelInterface::Instance() {
static RegisterKernelInterface instance;
return &instance;
}

int RegisterKernelInterface::Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) {
return lite::KernelInterfaceRegistry::Instance()->Reg(vendor, op_type, creator);
}
} // namespace kernel
} // namespace mindspore

+ 69
- 0
mindspore/lite/src/kernel_interface.h View File

@@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_
#define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_

#include <string>
#include <vector>
#include "include/ms_tensor.h"
#include "schema/model_generated.h"

namespace mindspore {
namespace kernel {
struct CapabilityParam {
float exec_time_;
float power_usage_;
};

class KernelInterface {
public:
virtual ~KernelInterface() = default;
virtual int Infer(const std::vector<tensor::MSTensor *> &tensor_in, std::vector<tensor::MSTensor *> *outputs,
const schema::Primitive *primitive) {
return 0;
}

virtual int GetCapability(const std::vector<tensor::MSTensor *> &tensor_in, const schema::Primitive *primitive,
CapabilityParam *param) {
return 0;
}
};
typedef KernelInterface *(*KernelInterfaceCreator)();

class RegisterKernelInterface {
public:
static RegisterKernelInterface *Instance();
int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator);

private:
RegisterKernelInterface() = default;
};

class KernelInterfaceReg {
public:
KernelInterfaceReg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) {
RegisterKernelInterface::Instance()->Reg(vendor, op_type, creator);
}
~KernelInterfaceReg() = default;
};

#define REGISTER_KERNEL_INTERFACE(vendor, op_type, creator) \
static KernelInterfaceReg g_##vendor##op_type##_inter_reg(vendor, op_type, creator);
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_

+ 50
- 0
mindspore/lite/src/kernel_interface_registry.cc View File

@@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/kernel_interface_registry.h"
#include "src/kernel_interface.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"

using mindspore::kernel::KernelInterfaceCreator;
using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore {
namespace lite {
namespace {
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
}

int KernelInterfaceRegistry::Reg(const std::string &vendor, const int &op_type, KernelInterfaceCreator creator) {
auto vendor_hash = std::hash<std::string>{}(vendor);
auto iter = kernel_interfaces_.find(vendor_hash);
if (iter == kernel_interfaces_.end()) {
kernel_interfaces_[vendor_hash] =
reinterpret_cast<KernelInterfaceCreator *>(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator)));
if (kernel_interfaces_[vendor_hash] == nullptr) {
MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
return RET_ERROR;
}
}
if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) {
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum;
return RET_ERROR;
}
kernel_interfaces_[vendor_hash][op_type] = creator;
return RET_OK;
}

} // namespace lite
} // namespace mindspore

+ 43
- 0
mindspore/lite/src/kernel_interface_registry.h View File

@@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_
#define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_

#include <string>
#include <unordered_map>
#include "src/kernel_interface.h"

namespace mindspore {
namespace lite {
class KernelInterfaceRegistry {
public:
static KernelInterfaceRegistry *Instance() {
static KernelInterfaceRegistry instance;
return &instance;
}

int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator);

private:
KernelInterfaceRegistry() = default;

std::unordered_map<size_t, kernel::KernelInterfaceCreator *> kernel_interfaces_;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_

+ 98
- 15
mindspore/lite/src/kernel_registry.cc View File

@@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include <utility>
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "src/common/version_manager.h" #include "src/common/version_manager.h"
@@ -26,13 +27,20 @@
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#endif #endif
#include "src/common/tensor_util.h"


using mindspore::kernel::CreateKernel;
using mindspore::kernel::kBuiltin;
using mindspore::kernel::kCPU; using mindspore::kernel::kCPU;
using mindspore::kernel::KERNEL_ARCH; using mindspore::kernel::KERNEL_ARCH;
using mindspore::kernel::KernelCreator; using mindspore::kernel::KernelCreator;
using mindspore::kernel::KernelKey; using mindspore::kernel::KernelKey;


namespace mindspore::lite { namespace mindspore::lite {
namespace {
static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin + 1) * (PrimitiveType_MAX - PrimitiveType_MIN + 1);
} // namespace

KernelRegistry *KernelRegistry::GetInstance() { KernelRegistry *KernelRegistry::GetInstance() {
static KernelRegistry instance; static KernelRegistry instance;


@@ -47,6 +55,47 @@ KernelRegistry *KernelRegistry::GetInstance() {
return &instance; return &instance;
} }


int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) {
int dType_index = static_cast<int>(desc.data_type) - kNumberTypeBegin;
return dType_index * op_type_length_ + desc.type;
}

int KernelRegistry::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type,
const int type, kernel::CreateKernel creator) {
auto vendor_hash = std::hash<std::string>{}(vendor);
auto arch_hash = std::hash<std::string>{}(arch);
auto iter = kernel_creators_.find(vendor_hash);
if (iter == kernel_creators_.end()) {
all_vendors_.insert(vendor);
kernel_creators_[vendor_hash][arch_hash] =
reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
if (kernel_creators_[vendor_hash][arch_hash] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch;
return RET_ERROR;
}
memset(kernel_creators_[vendor_hash][arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel));
} else {
auto iter_arch = iter->second.find(arch_hash);
if (iter_arch == iter->second.end()) {
iter->second[arch_hash] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
if (iter->second[arch_hash] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch;
return RET_ERROR;
}
memset(iter->second[arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel));
}
}

KernelKey desc = {kCPU, data_type, type, arch, vendor};
int index = GetFuncIndex(desc);
if (index >= kKernelMaxNum) {
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type;
return RET_ERROR;
}
kernel_creators_[vendor_hash][arch_hash][index] = creator;
return RET_OK;
}

int KernelRegistry::Init() { int KernelRegistry::Init() {
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
if (mindspore::lite::IsSupportSDot()) { if (mindspore::lite::IsSupportSDot()) {
@@ -66,17 +115,38 @@ int KernelRegistry::Init() {
} }


kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
int index = GetCreatorFuncIndex(desc);
if (index >= array_size_ || index < 0) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
<< desc.type;
if (desc.vendor == kBuiltin) {
int index = GetCreatorFuncIndex(desc);
if (index >= array_size_ || index < 0) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
<< desc.type;
return nullptr;
}
return creator_arrays_[index];
}
MS_LOG(ERROR) << "Call wrong interface!vendor: " << desc.vendor;
return nullptr;
}

kernel::CreateKernel KernelRegistry::GetDelegateCreator(const kernel::KernelKey &desc) {
auto vendor_hash = std::hash<std::string>{}(desc.vendor);
auto it_by_vendor = kernel_creators_.find(vendor_hash);
if (it_by_vendor == kernel_creators_.end()) {
return nullptr; return nullptr;
} }
auto it = creator_arrays_[index];
if (it != nullptr) {
return it;
auto arch_hash = std::hash<std::string>{}(desc.kernel_arch);
auto it_by_arch = it_by_vendor->second.find(arch_hash);
if (it_by_arch == it_by_vendor->second.end()) {
return nullptr;
} }
return nullptr;
auto index = GetFuncIndex(desc);
if (index < 0 || index >= kKernelMaxNum) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.kernel_arch << ", data_type" << desc.data_type << ",op type "
<< desc.type << ", vendor: " << desc.vendor;
return nullptr;
}

return it_by_arch->second[index];
} }


int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
@@ -127,15 +197,28 @@ bool KernelRegistry::SupportKernel(const KernelKey &key) {


kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, const InnerContext *ctx, const std::vector<Tensor *> &out_tensors, const InnerContext *ctx,
const kernel::KernelKey &key, OpParameter *parameter) {
const kernel::KernelKey &key, OpParameter *parameter,
const void *primitive) {
MS_ASSERT(ctx != nullptr); MS_ASSERT(ctx != nullptr);
auto creator = GetCreator(key);
if (creator != nullptr) {
auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
if (kernel != nullptr) {
kernel->set_desc(key);
return kernel;
if (key.vendor == kBuiltin) {
auto creator = GetCreator(key);
if (creator != nullptr) {
auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
if (kernel != nullptr) {
kernel->set_desc(key);
return kernel;
}
}
} else {
auto creator = GetDelegateCreator(key);
if (creator == nullptr) {
return nullptr;
} }
std::vector<tensor::MSTensor *> tensors_in;
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
std::vector<tensor::MSTensor *> tensors_out;
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
return creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
} }
return nullptr; return nullptr;
} }


+ 10
- 4
mindspore/lite/src/kernel_registry.h View File

@@ -20,7 +20,9 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <set>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
#include "src/register_kernel.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"


using mindspore::kernel::kKernelArch_MAX; using mindspore::kernel::kKernelArch_MAX;
@@ -37,16 +39,18 @@ class KernelRegistry {
static KernelRegistry *GetInstance(); static KernelRegistry *GetInstance();
static int Init(); static int Init();
virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc); virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc);
virtual kernel::CreateKernel GetDelegateCreator(const kernel::KernelKey &desc);
int GetCreatorFuncIndex(kernel::KernelKey desc); int GetCreatorFuncIndex(kernel::KernelKey desc);
int GetFuncIndex(const kernel::KernelKey &desc);
void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator); void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator);
void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator); void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator);
int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type,
kernel::CreateKernel creator);
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators); bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
int GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter,
kernel::LiteKernel **kernel);
bool SupportKernel(const kernel::KernelKey &key); bool SupportKernel(const kernel::KernelKey &key);
kernel::LiteKernel *GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, kernel::LiteKernel *GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter);
const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter,
const void *primitive = nullptr);


protected: protected:
static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1}; static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1};
@@ -54,6 +58,8 @@ class KernelRegistry {
static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1}; static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1};
static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_}; static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_};
kernel::KernelCreator *creator_arrays_ = nullptr; kernel::KernelCreator *creator_arrays_ = nullptr;
std::unordered_map<std::size_t, std::unordered_map<std::size_t, kernel::CreateKernel *>> kernel_creators_;
std::set<std::string> all_vendors_;


private: private:
std::mutex lock_; std::mutex lock_;


+ 8
- 1
mindspore/lite/src/lite_kernel.h View File

@@ -34,14 +34,21 @@


namespace mindspore::kernel { namespace mindspore::kernel {
enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU }; enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU };
static const char *const kBuiltin = "Builtin";


struct KernelKey { struct KernelKey {
KERNEL_ARCH arch; KERNEL_ARCH arch;
TypeId data_type; TypeId data_type;
int type; int type;
std::string kernel_arch;
std::string vendor{kBuiltin};


bool operator<(const KernelKey &dst) const { bool operator<(const KernelKey &dst) const {
if (arch != dst.arch) {
if (vendor != dst.vendor) {
return vendor < dst.vendor;
} else if (kernel_arch != dst.kernel_arch) {
return kernel_arch < dst.kernel_arch;
} else if (arch != dst.arch) {
return arch < dst.arch; return arch < dst.arch;
} else if (data_type != dst.data_type) { } else if (data_type != dst.data_type) {
return data_type < dst.data_type; return data_type < dst.data_type;


+ 31
- 0
mindspore/lite/src/register_kernel.cc View File

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/register_kernel.h"
#include "src/kernel_registry.h"

namespace mindspore {
namespace kernel {
RegisterKernel *RegisterKernel::GetInstance() {
static RegisterKernel instance;
return &instance;
}

int RegisterKernel::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type,
const int op_type, CreateKernel creator) {
return lite::KernelRegistry::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator);
}
} // namespace kernel
} // namespace mindspore

+ 51
- 0
mindspore/lite/src/register_kernel.h View File

@@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_
#define MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_

#include <string>
#include <vector>
#include "src/lite_kernel.h"

namespace mindspore {
namespace kernel {
typedef kernel::LiteKernel *(*CreateKernel)(const std::vector<tensor::MSTensor *> &inputs,
const std::vector<tensor::MSTensor *> &outputs,
const schema::Primitive *primitive, const lite::Context *ctx);
class RegisterKernel {
public:
static RegisterKernel *GetInstance();
int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type,
CreateKernel creator);
};

class KernelReg {
public:
~KernelReg() = default;

KernelReg(const std::string &arch, const std::string &vendor, const TypeId data_type, const int op_type,
CreateKernel creator) {
RegisterKernel::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator);
}
};

#define REGISTER_KERNEL(arch, vendor, data_type, op_type, creator) \
static KernelReg g_##arch##vendor##data_type##op_type##kernelReg(arch, vendor, data_type, op_type, creator);
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_

+ 1
- 0
mindspore/lite/test/CMakeLists.txt View File

@@ -142,6 +142,7 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/executor.cc ${LITE_DIR}/src/executor.cc
${LITE_DIR}/src/inner_context.cc ${LITE_DIR}/src/inner_context.cc
${LITE_DIR}/src/kernel_registry.cc ${LITE_DIR}/src/kernel_registry.cc
${LITE_DIR}/src/register_kernel.cc
${LITE_DIR}/src/lite_kernel.cc ${LITE_DIR}/src/lite_kernel.cc
${LITE_DIR}/src/lite_kernel_util.cc ${LITE_DIR}/src/lite_kernel_util.cc
${LITE_DIR}/src/lite_session.cc ${LITE_DIR}/src/lite_session.cc


+ 1
- 0
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -111,6 +111,7 @@ set(LITE_SRC
${SRC_DIR}/ms_tensor.cc ${SRC_DIR}/ms_tensor.cc
${SRC_DIR}/tensorlist.cc ${SRC_DIR}/tensorlist.cc
${SRC_DIR}/kernel_registry.cc ${SRC_DIR}/kernel_registry.cc
${SRC_DIR}/register_kernel.cc
${SRC_DIR}/lite_kernel.cc ${SRC_DIR}/lite_kernel.cc
${SRC_DIR}/lite_kernel_util.cc ${SRC_DIR}/lite_kernel_util.cc
${SRC_DIR}/scheduler.cc ${SRC_DIR}/scheduler.cc


Loading…
Cancel
Save