| @@ -19,13 +19,14 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <map> | |||
| #include "schema/model_generated.h" | |||
| #include "include/api/types.h" | |||
| #include "include/api/context.h" | |||
| namespace mindspore::kernel { | |||
| /// \brief The Kernel class is used to define a MindSpore Kernel. | |||
| class Kernel { | |||
| class MS_API Kernel { | |||
| public: | |||
| Kernel() = default; | |||
| /// \brief Constructor. | |||
| @@ -37,9 +38,7 @@ class Kernel { | |||
| Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs, | |||
| const schema::Primitive *primitive, const mindspore::Context *ctx) | |||
| : context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) { | |||
| if (primitive != nullptr) { | |||
| type_ = primitive->value_type(); | |||
| } | |||
| Initialize(); | |||
| } | |||
| /// \brief Destructor. | |||
| virtual ~Kernel() = default; | |||
| @@ -101,6 +100,23 @@ class Kernel { | |||
| /// | |||
| /// \return the primitive of kernel generated by flatbuffers. | |||
| const schema::Primitive *primitive() const { return this->primitive_; } | |||
| /// \brief get kernel's attribute | |||
| /// | |||
| /// \param[in] key define the kernel's attribute key. | |||
| std::string GetAttr(const std::string &key) const { | |||
| auto iter = attrs_.find(key); | |||
| if (iter != attrs_.end()) { | |||
| return iter->second; | |||
| } | |||
| return ""; | |||
| } | |||
| protected: | |||
| /// \brief set kernel's attribute | |||
| /// | |||
| /// \param[in] key define the kernel's attribute key. | |||
| /// \param[in] value define the kernel's attribute value. | |||
| void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; } | |||
| protected: | |||
| std::string name_; | |||
| @@ -109,6 +125,10 @@ class Kernel { | |||
| std::vector<mindspore::MSTensor> outputs_; | |||
| schema::PrimitiveType type_ = schema::PrimitiveType_NONE; | |||
| const schema::Primitive *primitive_ = nullptr; | |||
| std::map<std::string, std::string> attrs_; | |||
| private: | |||
| void Initialize(); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -25,6 +25,9 @@ | |||
| #include "schema/model_generated.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class Kernel; | |||
| } | |||
| namespace registry { | |||
| /// \brief KernelInterfaceCreator defined a functor to create KernelInterface. | |||
| using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>; | |||
| @@ -55,10 +58,12 @@ class MS_API RegisterKernelInterface { | |||
| /// | |||
| /// \param[in] provider Define the identification of user. | |||
| /// \param[in] primitive Define the attributes of a certain op. | |||
| /// \param[in] kernel Define the kernel of a certain op. | |||
| /// | |||
| /// \return Boolean value to represent registration of a certain op is existing or not. | |||
| static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider, | |||
| const schema::Primitive *primitive); | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel = nullptr); | |||
| }; | |||
| /// \brief KernelInterfaceReg defined registration class of KernelInterface. | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/api/kernel.h" | |||
| namespace mindspore::kernel { | |||
| void Kernel::Initialize() { | |||
| if (primitive_ == nullptr) { | |||
| return; | |||
| } | |||
| type_ = primitive_->value_type(); | |||
| if (type_ == schema::PrimitiveType_Custom) { | |||
| auto param = primitive_->value_as_Custom(); | |||
| if (param != nullptr && param->type() != nullptr) { | |||
| SetAttr("type", param->type()->str()); | |||
| } | |||
| } | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -20,6 +20,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/version_manager.h" | |||
| #include "schema/model_generated.h" | |||
| #include "include/api/kernel.h" | |||
| using mindspore::registry::KernelInterfaceCreator; | |||
| using mindspore::schema::PrimitiveType_MAX; | |||
| @@ -32,12 +33,10 @@ static constexpr auto KMaxCustomTypeNum = 200; | |||
| static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1; | |||
| std::string GetCustomType(const schema::Primitive *primitive) { | |||
| auto param = primitive->value_as_Custom(); | |||
| if (param == nullptr) { | |||
| return ""; | |||
| } | |||
| if (param->type() == nullptr) { | |||
| if (param == nullptr || param->type() == nullptr) { | |||
| return ""; | |||
| } | |||
| return param->type()->str(); | |||
| } | |||
| } // namespace | |||
| @@ -92,15 +91,19 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCache | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface( | |||
| const schema::Primitive *primitive) { | |||
| MS_ASSERT(primitive != nullptr); | |||
| const schema::Primitive *primitive, const kernel::Kernel *kernel) { | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| auto &&type = GetCustomType(primitive); | |||
| std::string type; | |||
| if (kernel == nullptr) { | |||
| type = GetCustomType(primitive); | |||
| } else { | |||
| type = kernel->GetAttr("type"); | |||
| } | |||
| for (auto &&item : custom_creators_) { | |||
| auto &&provider = item.first; | |||
| auto kernel = GetCustomCacheInterface(provider, type); | |||
| if (kernel != nullptr) { | |||
| return kernel; | |||
| auto kernel_interface = GetCustomCacheInterface(provider, type); | |||
| if (kernel_interface != nullptr) { | |||
| return kernel_interface; | |||
| } | |||
| auto provider_iter = custom_creators_.find(provider); | |||
| if (provider_iter == custom_creators_.end()) { | |||
| @@ -108,32 +111,38 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKerne | |||
| } | |||
| auto creator_iter = provider_iter->second.find(type); | |||
| if (creator_iter != provider_iter->second.end()) { | |||
| kernel = creator_iter->second(); | |||
| custom_kernels_[provider][type] = kernel; | |||
| return kernel; | |||
| kernel_interface = creator_iter->second(); | |||
| custom_kernels_[provider][type] = kernel_interface; | |||
| return kernel_interface; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface( | |||
| const std::string &provider, const schema::Primitive *primitive) { | |||
| if (primitive == nullptr) { | |||
| std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(const std::string &provider, | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel) { | |||
| if (primitive == nullptr && kernel == nullptr) { | |||
| return nullptr; | |||
| } | |||
| int op_type = static_cast<int>(primitive->value_type()); | |||
| int op_type; | |||
| if (kernel == nullptr) { | |||
| op_type = static_cast<int>(primitive->value_type()); | |||
| } else { | |||
| op_type = static_cast<int>(kernel->type()); | |||
| } | |||
| if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) { | |||
| return nullptr; | |||
| } | |||
| if (op_type == schema::PrimitiveType_Custom) { | |||
| return GetCustomKernelInterface(primitive); | |||
| return GetCustomKernelInterface(primitive, kernel); | |||
| } | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| auto kernel = GetCacheInterface(provider, op_type); | |||
| if (kernel != nullptr) { | |||
| return kernel; | |||
| auto kernel_interface = GetCacheInterface(provider, op_type); | |||
| if (kernel_interface != nullptr) { | |||
| return kernel_interface; | |||
| } | |||
| auto iter = kernel_creators_.find(provider); | |||
| if (iter == kernel_creators_.end()) { | |||
| @@ -142,9 +151,9 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInter | |||
| auto creator = iter->second[op_type]; | |||
| if (creator != nullptr) { | |||
| kernel = creator(); | |||
| kernel_interfaces_[provider][op_type] = kernel; | |||
| return kernel; | |||
| kernel_interface = creator(); | |||
| kernel_interfaces_[provider][op_type] = kernel_interface; | |||
| return kernel_interface; | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -35,7 +35,8 @@ class KernelInterfaceRegistry { | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider, | |||
| const schema::Primitive *primitive); | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel); | |||
| Status CustomReg(const std::string &provider, const std::string &op_type, | |||
| const registry::KernelInterfaceCreator creator); | |||
| Status Reg(const std::string &provider, int op_type, const registry::KernelInterfaceCreator creator); | |||
| @@ -46,7 +47,8 @@ class KernelInterfaceRegistry { | |||
| std::shared_ptr<kernel::KernelInterface> GetCacheInterface(const std::string &provider, int op_type); | |||
| std::shared_ptr<kernel::KernelInterface> GetCustomCacheInterface(const std::string &provider, | |||
| const std::string &type); | |||
| std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive); | |||
| std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel); | |||
| std::mutex mutex_; | |||
| // key: provider | |||
| @@ -41,10 +41,11 @@ Status RegisterKernelInterface::CustomReg(const std::string &provider, const std | |||
| #endif | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface( | |||
| const std::string &provider, const schema::Primitive *primitive) { | |||
| std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::string &provider, | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel) { | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive); | |||
| return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive, kernel); | |||
| #else | |||
| MS_LOG(ERROR) << unsupport_custom_kernel_register_log; | |||
| return nullptr; | |||
| @@ -34,23 +34,33 @@ namespace mindspore { | |||
| namespace lite { | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const void *primitive, std::set<std::string> &&providers, int schema_version) { | |||
| if (primitive == nullptr) { | |||
| const void *primitive, std::set<std::string> &&providers, int schema_version, | |||
| const kernel::Kernel *kernel) { | |||
| if (primitive == nullptr && kernel == nullptr) { | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> kernel_interface = nullptr; | |||
| if (IsCustomNode(primitive, schema_version)) { | |||
| kernel_interface = | |||
| registry::RegisterKernelInterface::GetKernelInterface("", static_cast<const schema::Primitive *>(primitive)); | |||
| bool is_custom_node = false; | |||
| if (kernel == nullptr) { | |||
| if (IsCustomNode(primitive, schema_version)) { | |||
| is_custom_node = true; | |||
| } | |||
| } else if (kernel->type() == schema::PrimitiveType_Custom) { | |||
| is_custom_node = true; | |||
| } | |||
| if (is_custom_node) { | |||
| kernel_interface = registry::RegisterKernelInterface::GetKernelInterface( | |||
| "", static_cast<const schema::Primitive *>(primitive), kernel); | |||
| } else { | |||
| for (auto &&provider : providers) { | |||
| kernel_interface = registry::RegisterKernelInterface::GetKernelInterface( | |||
| provider, static_cast<const schema::Primitive *>(primitive)); | |||
| provider, static_cast<const schema::Primitive *>(primitive), kernel); | |||
| if (kernel_interface != nullptr) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (kernel_interface == nullptr) { | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| @@ -26,13 +26,15 @@ | |||
| #include "src/tensor.h" | |||
| #include "nnacl/tensor_c.h" | |||
| #include "nnacl/infer/infer.h" | |||
| #include "include/api/kernel.h" | |||
| namespace mindspore::lite { | |||
| int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *parameter); | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const void *primitive, std::set<std::string> &&providers, int schema_version); | |||
| const void *primitive, std::set<std::string> &&providers, int schema_version, | |||
| const kernel::Kernel *kernel = nullptr); | |||
| #endif | |||
| class InferManager { | |||
| public: | |||
| @@ -428,7 +428,7 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel) | |||
| MS_ASSERT(conv_kernel); | |||
| MS_ASSERT(scale_kernel); | |||
| auto *scale_param = | |||
| reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel)->GetParameter()); | |||
| reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel->kernel())->GetParameter()); | |||
| MS_ASSERT(scale_param); | |||
| MS_ASSERT(conv_kernel->in_tensors().size() >= INPUT_TENSOR_SIZE_2); | |||
| auto *filter = conv_kernel->in_tensors().at(1); | |||
| @@ -107,7 +107,7 @@ int SubGraphKernel::ReSize() { | |||
| int ret; | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| ret = lite::KernelInferShape(inputs, outputs, kernel->kernel()->primitive(), kernel->Context()->GetProviders(), | |||
| schema_version_); | |||
| schema_version_, kernel->kernel()); | |||
| if (ret == lite::RET_NOT_SUPPORT) { | |||
| #endif | |||
| auto parameter = kernel->op_parameter(); | |||