Browse Source

support get attr from kernel

tags/v1.5.1
chenjianping 4 years ago
parent
commit
5a28bb8687
10 changed files with 122 additions and 43 deletions
  1. +24
    -4
      include/api/kernel.h
  2. +6
    -1
      mindspore/lite/include/registry/register_kernel_interface.h
  3. +30
    -0
      mindspore/lite/src/cxx_api/kernel.cc
  4. +33
    -24
      mindspore/lite/src/registry/kernel_interface_registry.cc
  5. +4
    -2
      mindspore/lite/src/registry/kernel_interface_registry.h
  6. +4
    -3
      mindspore/lite/src/registry/register_kernel_interface.cc
  7. +16
    -6
      mindspore/lite/src/runtime/infer_manager.cc
  8. +3
    -1
      mindspore/lite/src/runtime/infer_manager.h
  9. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
  10. +1
    -1
      mindspore/lite/src/sub_graph_kernel.cc

+ 24
- 4
include/api/kernel.h View File

@@ -19,13 +19,14 @@
#include <vector>
#include <string>
#include <utility>
#include <map>
#include "schema/model_generated.h"
#include "include/api/types.h"
#include "include/api/context.h"

namespace mindspore::kernel {
/// \brief The Kernel class is used to define a MindSpore Kernel.
class Kernel {
class MS_API Kernel {
public:
Kernel() = default;
/// \brief Constructor.
@@ -37,9 +38,7 @@ class Kernel {
Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx)
: context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) {
if (primitive != nullptr) {
type_ = primitive->value_type();
}
Initialize();
}
/// \brief Destructor.
virtual ~Kernel() = default;
@@ -101,6 +100,23 @@ class Kernel {
///
/// \return the primitive of kernel generated by flatbuffers.
const schema::Primitive *primitive() const { return this->primitive_; }
/// \brief get kernel's attribute
///
/// \param[in] key define the kernel's attribute key.
std::string GetAttr(const std::string &key) const {
auto iter = attrs_.find(key);
if (iter != attrs_.end()) {
return iter->second;
}
return "";
}

protected:
/// \brief set kernel's attribute
///
/// \param[in] key define the kernel's attribute key.
/// \param[in] value define the kernel's attribute value.
void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; }

protected:
std::string name_;
@@ -109,6 +125,10 @@ class Kernel {
std::vector<mindspore::MSTensor> outputs_;
schema::PrimitiveType type_ = schema::PrimitiveType_NONE;
const schema::Primitive *primitive_ = nullptr;
std::map<std::string, std::string> attrs_;

private:
void Initialize();
};
} // namespace mindspore::kernel



+ 6
- 1
mindspore/lite/include/registry/register_kernel_interface.h View File

@@ -25,6 +25,9 @@
#include "schema/model_generated.h"

namespace mindspore {
namespace kernel {
class Kernel;
}
namespace registry {
/// \brief KernelInterfaceCreator defined a functor to create KernelInterface.
using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
@@ -55,10 +58,12 @@ class MS_API RegisterKernelInterface {
///
/// \param[in] provider Define the identification of user.
/// \param[in] primitive Define the attributes of a certain op.
/// \param[in] kernel Define the kernel of a certain op.
///
/// \return Boolean value to represent registration of a certain op is existing or not.
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive);
const schema::Primitive *primitive,
const kernel::Kernel *kernel = nullptr);
};

/// \brief KernelInterfaceReg defined registration class of KernelInterface.


+ 30
- 0
mindspore/lite/src/cxx_api/kernel.cc View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/kernel.h"
namespace mindspore::kernel {
void Kernel::Initialize() {
if (primitive_ == nullptr) {
return;
}
type_ = primitive_->value_type();
if (type_ == schema::PrimitiveType_Custom) {
auto param = primitive_->value_as_Custom();
if (param != nullptr && param->type() != nullptr) {
SetAttr("type", param->type()->str());
}
}
}
} // namespace mindspore::kernel

+ 33
- 24
mindspore/lite/src/registry/kernel_interface_registry.cc View File

@@ -20,6 +20,7 @@
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "schema/model_generated.h"
#include "include/api/kernel.h"

using mindspore::registry::KernelInterfaceCreator;
using mindspore::schema::PrimitiveType_MAX;
@@ -32,12 +33,10 @@ static constexpr auto KMaxCustomTypeNum = 200;
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
std::string GetCustomType(const schema::Primitive *primitive) {
auto param = primitive->value_as_Custom();
if (param == nullptr) {
return "";
}
if (param->type() == nullptr) {
if (param == nullptr || param->type() == nullptr) {
return "";
}

return param->type()->str();
}
} // namespace
@@ -92,15 +91,19 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCache
}

std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface(
const schema::Primitive *primitive) {
MS_ASSERT(primitive != nullptr);
const schema::Primitive *primitive, const kernel::Kernel *kernel) {
std::unique_lock<std::mutex> lock(mutex_);
auto &&type = GetCustomType(primitive);
std::string type;
if (kernel == nullptr) {
type = GetCustomType(primitive);
} else {
type = kernel->GetAttr("type");
}
for (auto &&item : custom_creators_) {
auto &&provider = item.first;
auto kernel = GetCustomCacheInterface(provider, type);
if (kernel != nullptr) {
return kernel;
auto kernel_interface = GetCustomCacheInterface(provider, type);
if (kernel_interface != nullptr) {
return kernel_interface;
}
auto provider_iter = custom_creators_.find(provider);
if (provider_iter == custom_creators_.end()) {
@@ -108,32 +111,38 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKerne
}
auto creator_iter = provider_iter->second.find(type);
if (creator_iter != provider_iter->second.end()) {
kernel = creator_iter->second();
custom_kernels_[provider][type] = kernel;
return kernel;
kernel_interface = creator_iter->second();
custom_kernels_[provider][type] = kernel_interface;
return kernel_interface;
}
}

return nullptr;
}

std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
if (primitive == nullptr) {
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive,
const kernel::Kernel *kernel) {
if (primitive == nullptr && kernel == nullptr) {
return nullptr;
}
int op_type = static_cast<int>(primitive->value_type());
int op_type;
if (kernel == nullptr) {
op_type = static_cast<int>(primitive->value_type());
} else {
op_type = static_cast<int>(kernel->type());
}
if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) {
return nullptr;
}
if (op_type == schema::PrimitiveType_Custom) {
return GetCustomKernelInterface(primitive);
return GetCustomKernelInterface(primitive, kernel);
}

std::unique_lock<std::mutex> lock(mutex_);
auto kernel = GetCacheInterface(provider, op_type);
if (kernel != nullptr) {
return kernel;
auto kernel_interface = GetCacheInterface(provider, op_type);
if (kernel_interface != nullptr) {
return kernel_interface;
}
auto iter = kernel_creators_.find(provider);
if (iter == kernel_creators_.end()) {
@@ -142,9 +151,9 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInter

auto creator = iter->second[op_type];
if (creator != nullptr) {
kernel = creator();
kernel_interfaces_[provider][op_type] = kernel;
return kernel;
kernel_interface = creator();
kernel_interfaces_[provider][op_type] = kernel_interface;
return kernel_interface;
}
return nullptr;
}


+ 4
- 2
mindspore/lite/src/registry/kernel_interface_registry.h View File

@@ -35,7 +35,8 @@ class KernelInterfaceRegistry {
}

std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive);
const schema::Primitive *primitive,
const kernel::Kernel *kernel);
Status CustomReg(const std::string &provider, const std::string &op_type,
const registry::KernelInterfaceCreator creator);
Status Reg(const std::string &provider, int op_type, const registry::KernelInterfaceCreator creator);
@@ -46,7 +47,8 @@ class KernelInterfaceRegistry {
std::shared_ptr<kernel::KernelInterface> GetCacheInterface(const std::string &provider, int op_type);
std::shared_ptr<kernel::KernelInterface> GetCustomCacheInterface(const std::string &provider,
const std::string &type);
std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive);
std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive,
const kernel::Kernel *kernel);

std::mutex mutex_;
// key: provider


+ 4
- 3
mindspore/lite/src/registry/register_kernel_interface.cc View File

@@ -41,10 +41,11 @@ Status RegisterKernelInterface::CustomReg(const std::string &provider, const std
#endif
}

std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive,
const kernel::Kernel *kernel) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive, kernel);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return nullptr;


+ 16
- 6
mindspore/lite/src/runtime/infer_manager.cc View File

@@ -34,23 +34,33 @@ namespace mindspore {
namespace lite {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const void *primitive, std::set<std::string> &&providers, int schema_version) {
if (primitive == nullptr) {
const void *primitive, std::set<std::string> &&providers, int schema_version,
const kernel::Kernel *kernel) {
if (primitive == nullptr && kernel == nullptr) {
return RET_NOT_SUPPORT;
}
std::shared_ptr<kernel::KernelInterface> kernel_interface = nullptr;
if (IsCustomNode(primitive, schema_version)) {
kernel_interface =
registry::RegisterKernelInterface::GetKernelInterface("", static_cast<const schema::Primitive *>(primitive));
bool is_custom_node = false;
if (kernel == nullptr) {
if (IsCustomNode(primitive, schema_version)) {
is_custom_node = true;
}
} else if (kernel->type() == schema::PrimitiveType_Custom) {
is_custom_node = true;
}
if (is_custom_node) {
kernel_interface = registry::RegisterKernelInterface::GetKernelInterface(
"", static_cast<const schema::Primitive *>(primitive), kernel);
} else {
for (auto &&provider : providers) {
kernel_interface = registry::RegisterKernelInterface::GetKernelInterface(
provider, static_cast<const schema::Primitive *>(primitive));
provider, static_cast<const schema::Primitive *>(primitive), kernel);
if (kernel_interface != nullptr) {
break;
}
}
}

if (kernel_interface == nullptr) {
return RET_NOT_SUPPORT;
}


+ 3
- 1
mindspore/lite/src/runtime/infer_manager.h View File

@@ -26,13 +26,15 @@
#include "src/tensor.h"
#include "nnacl/tensor_c.h"
#include "nnacl/infer/infer.h"
#include "include/api/kernel.h"

namespace mindspore::lite {
int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
OpParameter *parameter);
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const void *primitive, std::set<std::string> &&providers, int schema_version);
const void *primitive, std::set<std::string> &&providers, int schema_version,
const kernel::Kernel *kernel = nullptr);
#endif
class InferManager {
public:


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc View File

@@ -428,7 +428,7 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel)
MS_ASSERT(conv_kernel);
MS_ASSERT(scale_kernel);
auto *scale_param =
reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel)->GetParameter());
reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel->kernel())->GetParameter());
MS_ASSERT(scale_param);
MS_ASSERT(conv_kernel->in_tensors().size() >= INPUT_TENSOR_SIZE_2);
auto *filter = conv_kernel->in_tensors().at(1);


+ 1
- 1
mindspore/lite/src/sub_graph_kernel.cc View File

@@ -107,7 +107,7 @@ int SubGraphKernel::ReSize() {
int ret;
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
ret = lite::KernelInferShape(inputs, outputs, kernel->kernel()->primitive(), kernel->Context()->GetProviders(),
schema_version_);
schema_version_, kernel->kernel());
if (ret == lite::RET_NOT_SUPPORT) {
#endif
auto parameter = kernel->op_parameter();


Loading…
Cancel
Save