Browse Source

support get config and attr from kernel

tags/v1.6.0
chenjianping 4 years ago
parent
commit
8acc951a47
28 changed files with 405 additions and 161 deletions
  1. +46
    -4
      include/api/kernel.h
  2. +15
    -0
      include/api/model.h
  3. +8
    -8
      mindspore/lite/include/registry/register_kernel.h
  4. +22
    -13
      mindspore/lite/include/registry/register_kernel_interface.h
  5. +56
    -39
      mindspore/lite/src/common/config_file.cc
  6. +2
    -4
      mindspore/lite/src/common/config_file.h
  7. +30
    -0
      mindspore/lite/src/cxx_api/kernel.cc
  8. +13
    -0
      mindspore/lite/src/cxx_api/model/model.cc
  9. +34
    -5
      mindspore/lite/src/cxx_api/model/model_impl.cc
  10. +2
    -0
      mindspore/lite/src/cxx_api/model/model_impl.h
  11. +1
    -0
      mindspore/lite/src/lite_session.cc
  12. +5
    -0
      mindspore/lite/src/lite_session.h
  13. +3
    -3
      mindspore/lite/src/ops/populate/affine_populate.cc
  14. +2
    -2
      mindspore/lite/src/ops/populate/control/tensor_array_populate.cc
  15. +6
    -6
      mindspore/lite/src/ops/populate/splice_populate.cc
  16. +59
    -26
      mindspore/lite/src/registry/kernel_interface_registry.cc
  17. +7
    -4
      mindspore/lite/src/registry/kernel_interface_registry.h
  18. +2
    -2
      mindspore/lite/src/registry/register_kernel.cc
  19. +52
    -25
      mindspore/lite/src/registry/register_kernel_impl.cc
  20. +2
    -3
      mindspore/lite/src/registry/register_kernel_impl.h
  21. +7
    -5
      mindspore/lite/src/registry/register_kernel_interface.cc
  22. +16
    -6
      mindspore/lite/src/runtime/infer_manager.cc
  23. +3
    -1
      mindspore/lite/src/runtime/infer_manager.h
  24. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
  25. +3
    -0
      mindspore/lite/src/scheduler.cc
  26. +4
    -0
      mindspore/lite/src/scheduler.h
  27. +1
    -1
      mindspore/lite/src/sub_graph_kernel.cc
  28. +3
    -3
      mindspore/lite/test/st/mix_data_type_test.cc

+ 46
- 4
include/api/kernel.h View File

@@ -19,13 +19,14 @@
#include <vector>
#include <string>
#include <utility>
#include <map>
#include "schema/model_generated.h"
#include "include/api/types.h"
#include "include/api/context.h"

namespace mindspore::kernel {
/// \brief The Kernel class is used to define a MindSpore Kernel.
class Kernel {
class MS_API Kernel {
public:
Kernel() = default;
/// \brief Constructor.
@@ -37,9 +38,7 @@ class Kernel {
Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx)
: context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) {
if (primitive != nullptr) {
type_ = primitive->value_type();
}
Initialize();
}
/// \brief Destructor.
virtual ~Kernel() = default;
@@ -102,6 +101,44 @@ class Kernel {
/// \return the primitive of kernel generated by flatbuffers.
const schema::Primitive *primitive() const { return this->primitive_; }

/// \brief get kernel's attribute.
///
/// \param[in] key define the kernel's attribute key.
std::string GetAttr(const std::string &key) const {
auto iter = attrs_.find(key);
if (iter != attrs_.end()) {
return iter->second;
}
return "";
}

/// \brief set kernel's config.
///
/// \param[in] config define the kernel's config.
void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config) {
config_ = config;
}
/// \brief set kernel's config.
///
/// \param[in] config define the kernel's config.
std::map<std::string, std::string> GetConfig(const std::string &section) const {
if (config_ == nullptr) {
return std::map<std::string, std::string>();
}
auto iter = config_->find(section);
if (iter != config_->end()) {
return iter->second;
}
return std::map<std::string, std::string>();
}

protected:
/// \brief set kernel's attribute
///
/// \param[in] key define the kernel's attribute key.
/// \param[in] value define the kernel's attribute value.
void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; }

protected:
std::string name_;
const mindspore::Context *context_ = nullptr;
@@ -109,6 +146,11 @@ class Kernel {
std::vector<mindspore::MSTensor> outputs_;
schema::PrimitiveType type_ = schema::PrimitiveType_NONE;
const schema::Primitive *primitive_ = nullptr;
std::map<std::string, std::string> attrs_;
const std::map<std::string, std::map<std::string, std::string>> *config_;

private:
void Initialize();
};
} // namespace mindspore::kernel



+ 15
- 0
include/api/model.h View File

@@ -106,6 +106,14 @@ class MS_API Model {
/// \return Status.
inline Status LoadConfig(const std::string &config_path);

/// \brief Update config.
///
/// \param[in] section define the config section.
/// \param[in] config define the config will be updated.
///
/// \return Status.
inline Status UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config);

/// \brief Obtains all input tensors of the model.
///
/// \return The vector that includes all input tensors.
@@ -215,6 +223,7 @@ class MS_API Model {
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
Status LoadConfig(const std::vector<char> &config_path);
Status UpdateConfig(const std::vector<char> &section, const std::pair<std::vector<char>, std::vector<char>> &config);
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
@@ -241,6 +250,12 @@ Status Model::LoadConfig(const std::string &config_path) {
return LoadConfig(StringToChar(config_path));
}

Status Model::UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config) {
std::pair<std::vector<char>, std::vector<char>> config_pair = {StringToChar(config.first),
StringToChar(config.second)};
return UpdateConfig(StringToChar(section), config_pair);
}

Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode));


+ 8
- 8
mindspore/lite/include/registry/register_kernel.h View File

@@ -71,7 +71,7 @@ class MS_API RegisterKernel {
///
/// \return Status as a status identification of registering.
inline static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator);
const CreateKernel creator);

/// \brief Static method to register kernel which is corresponding to custom op.
///
@@ -83,7 +83,7 @@ class MS_API RegisterKernel {
///
/// \return Status as a status identification of registering.
inline static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator);
const std::string &type, const CreateKernel creator);

/// \brief Static methon to get a kernel's create function.
///
@@ -95,9 +95,9 @@ class MS_API RegisterKernel {

private:
static Status RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
int type, CreateKernel creator);
int type, const CreateKernel creator);
static Status RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
const std::vector<char> &type, CreateKernel creator);
const std::vector<char> &type, const CreateKernel creator);
static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc);
};

@@ -115,7 +115,7 @@ class MS_API KernelReg {
/// \param[in] op_type Define the ordinary op type.
/// \param[in] creator Define a function pointer to create a kernel.
KernelReg(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
CreateKernel creator) {
const CreateKernel creator) {
RegisterKernel::RegKernel(arch, provider, data_type, op_type, creator);
}

@@ -127,18 +127,18 @@ class MS_API KernelReg {
/// \param[in] op_type Define the concrete type of a custom op.
/// \param[in] creator Define a function pointer to create a kernel.
KernelReg(const std::string &arch, const std::string &provider, DataType data_type, const std::string &op_type,
CreateKernel creator) {
const CreateKernel creator) {
RegisterKernel::RegCustomKernel(arch, provider, data_type, op_type, creator);
}
};

Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator) {
const CreateKernel creator) {
return RegKernel(StringToChar(arch), StringToChar(provider), data_type, type, creator);
}

Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) {
const std::string &type, const CreateKernel creator) {
return RegCustomKernel(StringToChar(arch), StringToChar(provider), data_type, StringToChar(type), creator);
}



+ 22
- 13
mindspore/lite/include/registry/register_kernel_interface.h View File

@@ -25,6 +25,9 @@
#include "schema/model_generated.h"

namespace mindspore {
namespace kernel {
class Kernel;
}
namespace registry {
/// \brief KernelInterfaceCreator defined a functor to create KernelInterface.
using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
@@ -40,7 +43,7 @@ class MS_API RegisterKernelInterface {
///
/// \return Status as a status identification of registering.
inline static Status CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator);
const KernelInterfaceCreator creator);

/// \brief Static method to register op whose primitive type is ordinary.
///
@@ -49,23 +52,26 @@ class MS_API RegisterKernelInterface {
/// \param[in] creator Define the KernelInterface create function.
///
/// \return Status as a status identification of registering.
inline static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
inline static Status Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator);

/// \brief Static method to get registration of a certain op.
///
/// \param[in] provider Define the identification of user.
/// \param[in] primitive Define the attributes of a certain op.
/// \param[in] kernel Define the kernel of a certain op.
///
/// \return Boolean value to represent registration of a certain op is existing or not.
inline static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive);
const schema::Primitive *primitive,
const kernel::Kernel *kernel = nullptr);

private:
static Status CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
KernelInterfaceCreator creator);
static Status Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator);
const KernelInterfaceCreator creator);
static Status Reg(const std::vector<char> &provider, int op_type, const KernelInterfaceCreator creator);
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::vector<char> &provider,
const schema::Primitive *primitive);
const schema::Primitive *primitive,
const kernel::Kernel *kernel = nullptr);
};

/// \brief KernelInterfaceReg defined registration class of KernelInterface.
@@ -76,7 +82,7 @@ class MS_API KernelInterfaceReg {
/// \param[in] provider Define the identification of user.
/// \param[in] op_type Define the ordinary op type.
/// \param[in] creator Define the KernelInterface create function.
KernelInterfaceReg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
RegisterKernelInterface::Reg(provider, op_type, creator);
}

@@ -85,23 +91,26 @@ class MS_API KernelInterfaceReg {
/// \param[in] provider Define the identification of user.
/// \param[in] op_type Define the concrete type of a custom op.
/// \param[in] creator Define the KernelInterface create function.
KernelInterfaceReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator) {
KernelInterfaceReg(const std::string &provider, const std::string &op_type, const KernelInterfaceCreator creator) {
RegisterKernelInterface::CustomReg(provider, op_type, creator);
}

virtual ~KernelInterfaceReg() = default;
};

Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator) {
const KernelInterfaceCreator creator) {
return CustomReg(StringToChar(provider), StringToChar(op_type), creator);
}

Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
return Reg(StringToChar(provider), op_type, creator);
}

std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
return GetKernelInterface(StringToChar(provider), primitive);
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive,
const kernel::Kernel *kernel) {
return GetKernelInterface(StringToChar(provider), primitive, kernel);
}

/// \brief Defined registering macro to register ordinary op, which called by user directly.


+ 56
- 39
mindspore/lite/src/common/config_file.cc View File

@@ -21,19 +21,55 @@
#endif
namespace {
constexpr size_t kLengthOfParentheses = 2;
}
constexpr size_t kMinSectionLineLength = 2;
constexpr size_t kMaxValidLineCount = 100000;
constexpr size_t kMaxLineCount = 100100;

} // namespace

namespace mindspore {
namespace lite {
int GetSectionInfoFromConfigFile(const std::string &file, const std::string &section_name,
std::map<std::string, std::string> *section_info) {
if (file.empty()) {
MS_LOG(ERROR) << "file is nullptr";
namespace {
void ParseLine(const std::string &line, std::map<std::string, std::string> *section_config, std::string *section,
size_t *valid_line_count, std::map<std::string, std::map<std::string, std::string>> *config) {
// eg: [section]
// key=value
if (line[0] == '[' && line[line.length() - 1] == ']') {
if (!section->empty() && !section_config->empty()) {
config->insert(std::make_pair(*section, *section_config));
}
section_config->clear();
*section = line.substr(1, line.length() - kLengthOfParentheses);
*valid_line_count = *valid_line_count + 1;
}

if (!section->empty()) {
auto index = line.find('=');
if (index == std::string::npos) {
return;
}
auto key = line.substr(0, index);
if (index + 1 > line.size()) {
return;
}
auto value = line.substr(index + 1);
lite::Trim(&key);
lite::Trim(&value);
section_config->insert(std::make_pair(key, value));
*valid_line_count = *valid_line_count + 1;
}
}
} // namespace

int GetAllSectionInfoFromConfigFile(const std::string &file,
std::map<std::string, std::map<std::string, std::string>> *config) {
if (file.empty() || config == nullptr) {
MS_LOG(ERROR) << "input Invalid!check file and config.";
return RET_ERROR;
}
auto resolved_path = std::make_unique<char[]>(PATH_MAX);
if (resolved_path == nullptr) {
MS_LOG(ERROR) << "new resolved_path failed";
MS_LOG(ERROR) << "new resolved_path fail!";
return RET_ERROR;
}

@@ -56,44 +92,25 @@ int GetSectionInfoFromConfigFile(const std::string &file, const std::string &sec
return RET_ERROR;
}
std::string line;

bool find_section = false;
std::string section;
std::map<std::string, std::string> section_config;
size_t line_count = 0;
size_t valid_line_count = 0;
while (std::getline(ifs, line)) {
lite::Trim(&line);
if (line.empty()) {
continue;
line_count++;
if (line_count >= kMaxLineCount || valid_line_count >= kMaxValidLineCount) {
MS_LOG(ERROR) << "config too many lines!";
return RET_ERROR;
}
if (line[0] == '#') {
lite::Trim(&line);
if (line.length() <= kMinSectionLineLength || line[0] == '#') {
continue;
}

if (line[0] == '[') {
if (find_section == true) {
break;
}
std::string section = line.substr(1, line.length() - kLengthOfParentheses);
if (section != section_name) {
continue;
}
find_section = true;
}

if (find_section == true) {
auto index = line.find('=');
if (index == std::string::npos) {
continue;
}
auto key = line.substr(0, index);
if (index + 1 > line.size()) {
return RET_ERROR;
}
auto value = line.substr(index + 1);
lite::Trim(&key);
lite::Trim(&value);
section_info->insert(std::make_pair(key, value));
}
ParseLine(line, &section_config, &section, &valid_line_count, config);
}
if (!section.empty() && !section_config.empty()) {
config->insert(std::make_pair(section, section_config));
}

ifs.close();
return RET_OK;
}


+ 2
- 4
mindspore/lite/src/common/config_file.h View File

@@ -35,10 +35,8 @@
namespace mindspore {
namespace lite {
constexpr int MAX_CONFIG_FILE_LENGTH = 1024;
#define CONFIG_FILE_EXECUTION_PLAN "execution_plan"

int GetSectionInfoFromConfigFile(const std::string &file, const std::string &section_name,
std::map<std::string, std::string> *section_info);
int GetAllSectionInfoFromConfigFile(const std::string &file,
std::map<std::string, std::map<std::string, std::string>> *config);

void ParserExecutionPlan(const std::map<std::string, std::string> *config_infos,
std::map<std::string, TypeId> *data_type_plan);


+ 30
- 0
mindspore/lite/src/cxx_api/kernel.cc View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/kernel.h"
namespace mindspore::kernel {
void Kernel::Initialize() {
if (primitive_ == nullptr) {
return;
}
type_ = primitive_->value_type();
if (type_ == schema::PrimitiveType_Custom) {
auto param = primitive_->value_as_Custom();
if (param != nullptr && param->type() != nullptr) {
SetAttr("type", param->type()->str());
}
}
}
} // namespace mindspore::kernel

+ 13
- 0
mindspore/lite/src/cxx_api/model/model.cc View File

@@ -209,6 +209,19 @@ Status Model::LoadConfig(const std::vector<char> &config_path) {
return kSuccess;
}

Status Model::UpdateConfig(const std::vector<char> &section,
const std::pair<std::vector<char>, std::vector<char>> &config) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
if (impl_ == nullptr) {
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
}
if (impl_ != nullptr) {
return impl_->UpdateConfig(CharToString(section), {CharToString(config.first), CharToString(config.second)});
}
MS_LOG(ERROR) << "Model implement is null!";
return kLiteFileError;
}

Status Model::SetTrainMode(bool train) {
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
MS_LOG(ERROR) << "Model is null.";


+ 34
- 5
mindspore/lite/src/cxx_api/model/model_impl.cc View File

@@ -17,6 +17,10 @@
#include "src/cxx_api/model/model_impl.h"
#include <memory>
#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "include/api/types.h"
#include "include/api/context.h"
#include "include/lite_session.h"
@@ -32,6 +36,11 @@
#include "src/common/config_file.h"

namespace mindspore {
namespace {
static const char *kExecutionPlan = "execution_plan";
static constexpr size_t kMaxSectionNum = 100;
static constexpr size_t kMaxConfigNumPerSection = 1000;
} // namespace
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

@@ -195,15 +204,16 @@ Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBac
bool ModelImpl::IsTrainModel() { return (graph_ && graph_->graph_data_ && graph_->graph_data_->IsTrainModel()); }

Status ModelImpl::LoadConfig(const std::string &config_path) {
std::map<std::string, std::string> config_info;
int ret = lite::GetSectionInfoFromConfigFile(config_path, CONFIG_FILE_EXECUTION_PLAN, &config_info);
std::map<std::string, std::map<std::string, std::string>> all_config_info;
int ret = lite::GetAllSectionInfoFromConfigFile(config_path, &all_config_info);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetSectionInfoFromConfigFile failed.";
MS_LOG(ERROR) << "GetAllSectionInfoFromConfigFile fail!ret: " << ret;
return kLiteFileError;
}

config_info_ = all_config_info;
std::map<std::string, std::string> config_info = all_config_info[kExecutionPlan];
if (config_info.empty()) {
MS_LOG(WARNING) << "No valid info in config file.";
MS_LOG(WARNING) << "No valid execution plan info in config file.";
return kSuccess;
}

@@ -211,6 +221,24 @@ Status ModelImpl::LoadConfig(const std::string &config_path) {
return kSuccess;
}

Status ModelImpl::UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config) {
auto iter = config_info_.find(section);
if (iter == config_info_.end()) {
if (config_info_.size() >= kMaxSectionNum) {
MS_LOG(ERROR) << "config too many sections!";
return kLiteError;
}
config_info_[section][config.first] = config.second;
return kSuccess;
}
if (iter->second.size() >= kMaxConfigNumPerSection) {
MS_LOG(ERROR) << "config too many items!";
return kLiteError;
}
iter->second[config.first] = config.second;
return kSuccess;
}

Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (outputs == nullptr) {
@@ -567,6 +595,7 @@ session::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context)
}

session->InitExecutionConfig(&execution_plan_);
session->SetConfigInfo(&config_info_);

auto ret = session->Init(context);
if (ret != mindspore::lite::RET_OK) {


+ 2
- 0
mindspore/lite/src/cxx_api/model/model_impl.h View File

@@ -70,6 +70,7 @@ class ModelImpl {
session::LiteSession *CreateLiteSession(lite::InnerContext *context);

Status LoadConfig(const std::string &config_path);
Status UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config);
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
std::vector<MSTensor> GetGradients() const;
@@ -112,6 +113,7 @@ class ModelImpl {
void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; }
Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
std::map<std::string, TypeId> execution_plan_;
std::map<std::string, std::map<std::string, std::string>> config_info_;
};
} // namespace mindspore



+ 1
- 0
mindspore/lite/src/lite_session.cc View File

@@ -523,6 +523,7 @@ int LiteSession::CompileGraph(Model *model) {
Scheduler scheduler(context_, ms_context_, model, &tensors_, inputs_, outputs_, is_train_session_, &is_infershape_,
&is_control_flow_, execution_plan_, delegate_, delegate_device_type_);
scheduler.SetupSchedulerCb(std::move(sched_cb_));
scheduler.SetConfig(config_info_);
ret = scheduler.Schedule(&kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Schedule kernels failed: " << ret;


+ 5
- 0
mindspore/lite/src/lite_session.h View File

@@ -87,6 +87,10 @@ class LiteSession : public session::LiteSession {

const Delegate *get_delegate() const { return this->delegate_.get(); }

void SetConfigInfo(const std::map<std::string, std::map<std::string, std::string>> *config_info) {
config_info_ = config_info;
}

protected:
static void ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lite::Tensor *dst_tensor);

@@ -182,6 +186,7 @@ class LiteSession : public session::LiteSession {
std::shared_ptr<Delegate> delegate_ = nullptr;
int delegate_device_type_ = -1; // -1: not specified; 0: CPU; 1: GPU; 2: NPU
std::map<std::string, TypeId> *execution_plan_ = nullptr;
const std::map<std::string, std::map<std::string, std::string>> *config_info_ = nullptr;
};
} // namespace lite
} // namespace mindspore


+ 3
- 3
mindspore/lite/src/ops/populate/affine_populate.cc View File

@@ -70,14 +70,14 @@ OpParameter *PopulateAffineParameter(const void *prim) {
affine_param->context_size_ = static_cast<int>(context.size());

// malloc && memset for context
affine_param->context_ = reinterpret_cast<int *>(malloc(affine_param->context_size_ * sizeof(int)));
affine_param->context_ = reinterpret_cast<int *>(malloc(context.size() * sizeof(int)));
if (affine_param->context_ == nullptr) {
MS_LOG(ERROR) << "malloc param context_ for affine layer failed!";
ReleaseParam(affine_param, matmul_param);
return nullptr;
}
memset(affine_param->context_, 0, affine_param->context_size_ * sizeof(int));
for (int i = 0; i < affine_param->context_size_; ++i) {
(void)memset(affine_param->context_, 0, context.size() * sizeof(int));
for (size_t i = 0; i < context.size(); ++i) {
affine_param->context_[i] = context.at(i);
}
affine_param->output_dim_ = value->output_dim();


+ 2
- 2
mindspore/lite/src/ops/populate/control/tensor_array_populate.cc View File

@@ -43,8 +43,8 @@ OpParameter *PopulateTensorArrayParameter(const void *prim) {
bool identical_element_shapes = value->identical_element_shapes();
param->identical_element_shapes_ = identical_element_shapes;
std::vector<int> primitive_element_shape(value->element_shape()->begin(), value->element_shape()->end());
param->element_shape_size_ = primitive_element_shape.size();
int size = sizeof(int) * param->element_shape_size_;
param->element_shape_size_ = static_cast<int>(primitive_element_shape.size());
auto size = sizeof(int) * param->element_shape_size_;
param->element_shape_ = static_cast<int *>(malloc(size));
if (param->element_shape_ == nullptr) {
MS_LOG(ERROR) << "malloc element_shape failed!";


+ 6
- 6
mindspore/lite/src/ops/populate/splice_populate.cc View File

@@ -52,7 +52,7 @@ OpParameter *PopulateSpliceParameter(const void *prim) {
param->context_dim_ = static_cast<int>(primitive_context.size());

// malloc && memset for context
param->context_ = reinterpret_cast<int *>(malloc(param->context_dim_ * sizeof(int)));
param->context_ = reinterpret_cast<int *>(malloc(primitive_context.size() * sizeof(int)));
if (param->context_ == nullptr) {
MS_LOG(ERROR) << "malloc param context_ error";
free(param);
@@ -60,8 +60,8 @@ OpParameter *PopulateSpliceParameter(const void *prim) {
}
// src_to_dst_row_offset
int src_to_dst_row_offset = INT32_MIN;
memset(param->context_, 0, param->context_dim_ * sizeof(int));
for (int i = 0; i < param->context_dim_; ++i) {
(void)memset(param->context_, 0, primitive_context.size() * sizeof(int));
for (size_t i = 0; i < primitive_context.size(); ++i) {
param->context_[i] = primitive_context[i];
src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i)));
}
@@ -83,15 +83,15 @@ OpParameter *PopulateSpliceParameter(const void *prim) {
param->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size());

// malloc && memset for forward_indexes
param->forward_indexes_ = reinterpret_cast<int *>(malloc(param->forward_indexes_dim_ * sizeof(int)));
param->forward_indexes_ = reinterpret_cast<int *>(malloc(primitive_context.size() * sizeof(int)));
if (param->forward_indexes_ == nullptr) {
MS_LOG(ERROR) << "malloc param forward_indexes_ error";
free(param->context_);
free(param);
return nullptr;
}
memset(param->forward_indexes_, 0, param->forward_indexes_dim_ * sizeof(int));
memcpy(param->forward_indexes_, primitive_forward_indexes.data(), param->forward_indexes_dim_ * sizeof(int));
(void)memset(param->forward_indexes_, 0, primitive_context.size() * sizeof(int));
(void)memcpy(param->forward_indexes_, primitive_forward_indexes.data(), primitive_context.size() * sizeof(int));
param->output_dim_ = value->output_dim();
return reinterpret_cast<OpParameter *>(param);
}


+ 59
- 26
mindspore/lite/src/registry/kernel_interface_registry.cc View File

@@ -20,6 +20,7 @@
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "schema/model_generated.h"
#include "include/api/kernel.h"

using mindspore::registry::KernelInterfaceCreator;
using mindspore::schema::PrimitiveType_MAX;
@@ -27,16 +28,33 @@ using mindspore::schema::PrimitiveType_MIN;
namespace mindspore {
namespace registry {
namespace {
static constexpr auto kMaxProviderNum = 10;
static constexpr auto KMaxCustomTypeNum = 200;
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
std::string GetCustomType(const schema::Primitive *primitive) {
auto param = primitive->value_as_Custom();
MS_ASSERT(param != nullptr);
if (param == nullptr || param->type() == nullptr) {
return "";
}

return param->type()->str();
}
} // namespace

Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
KernelInterfaceCreator creator) {
const KernelInterfaceCreator creator) {
auto provider_iter = custom_creators_.find(provider);
if (provider_iter == custom_creators_.end() && custom_creators_.size() >= kMaxProviderNum) {
MS_LOG(ERROR) << "register too many provider!";
return kLiteError;
}
if (provider_iter != custom_creators_.end()) {
auto type_iter = provider_iter->second.find(type);
if (type_iter == provider_iter->second.end() && provider_iter->second.size() >= KMaxCustomTypeNum) {
MS_LOG(ERROR) << "register too many custom type!";
return kLiteError;
}
}
custom_creators_[provider][type] = creator;
return kSuccess;
}
@@ -73,15 +91,19 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCache
}

std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface(
const schema::Primitive *primitive) {
MS_ASSERT(primitive != nullptr);
const schema::Primitive *primitive, const kernel::Kernel *kernel) {
std::unique_lock<std::mutex> lock(mutex_);
auto &&type = GetCustomType(primitive);
std::string type;
if (kernel == nullptr) {
type = GetCustomType(primitive);
} else {
type = kernel->GetAttr("type");
}
for (auto &&item : custom_creators_) {
auto &&provider = item.first;
auto kernel = GetCustomCacheInterface(provider, type);
if (kernel != nullptr) {
return kernel;
auto kernel_interface = GetCustomCacheInterface(provider, type);
if (kernel_interface != nullptr) {
return kernel_interface;
}
auto provider_iter = custom_creators_.find(provider);
if (provider_iter == custom_creators_.end()) {
@@ -89,47 +111,54 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKerne
}
auto creator_iter = provider_iter->second.find(type);
if (creator_iter != provider_iter->second.end()) {
kernel = creator_iter->second();
custom_kernels_[provider][type] = kernel;
return kernel;
kernel_interface = creator_iter->second();
custom_kernels_[provider][type] = kernel_interface;
return kernel_interface;
}
}

return nullptr;
}

std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
if (primitive == nullptr) {
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive,
const kernel::Kernel *kernel) {
if (primitive == nullptr && kernel == nullptr) {
return nullptr;
}
int op_type;
if (kernel == nullptr) {
op_type = static_cast<int>(primitive->value_type());
} else {
op_type = static_cast<int>(kernel->type());
}
if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) {
return nullptr;
}
int op_type = primitive->value_type();
if (op_type == schema::PrimitiveType_Custom) {
return GetCustomKernelInterface(primitive);
return GetCustomKernelInterface(primitive, kernel);
}

std::unique_lock<std::mutex> lock(mutex_);
auto kernel = GetCacheInterface(provider, op_type);
if (kernel != nullptr) {
return kernel;
auto kernel_interface = GetCacheInterface(provider, op_type);
if (kernel_interface != nullptr) {
return kernel_interface;
}
auto iter = kernel_creators_.find(provider);
if (iter == kernel_creators_.end()) {
return nullptr;
}
if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) {
return nullptr;
}

auto creator = iter->second[op_type];
if (creator != nullptr) {
kernel = creator();
kernel_interfaces_[provider][op_type] = kernel;
return kernel;
kernel_interface = creator();
kernel_interfaces_[provider][op_type] = kernel_interface;
return kernel_interface;
}
return nullptr;
}

Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
if (op_type <= PrimitiveType_MIN || op_type > PrimitiveType_MAX) {
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << PrimitiveType_MAX;
return kLiteParamInvalid;
@@ -142,6 +171,10 @@ Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, Ke
std::unique_lock<std::mutex> lock(mutex_);
auto iter = kernel_creators_.find(provider);
if (iter == kernel_creators_.end()) {
if (kernel_creators_.size() >= kMaxProviderNum) {
MS_LOG(ERROR) << "register too many provider!";
return kLiteError;
}
kernel_creators_[provider] =
reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator)));
if (kernel_creators_[provider] == nullptr) {


+ 7
- 4
mindspore/lite/src/registry/kernel_interface_registry.h View File

@@ -35,9 +35,11 @@ class KernelInterfaceRegistry {
}

std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive);
Status CustomReg(const std::string &provider, const std::string &op_type, registry::KernelInterfaceCreator creator);
Status Reg(const std::string &provider, int op_type, registry::KernelInterfaceCreator creator);
const schema::Primitive *primitive,
const kernel::Kernel *kernel);
Status CustomReg(const std::string &provider, const std::string &op_type,
const registry::KernelInterfaceCreator creator);
Status Reg(const std::string &provider, int op_type, const registry::KernelInterfaceCreator creator);
virtual ~KernelInterfaceRegistry();

private:
@@ -45,7 +47,8 @@ class KernelInterfaceRegistry {
std::shared_ptr<kernel::KernelInterface> GetCacheInterface(const std::string &provider, int op_type);
std::shared_ptr<kernel::KernelInterface> GetCustomCacheInterface(const std::string &provider,
const std::string &type);
std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive);
std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive,
const kernel::Kernel *kernel);

std::mutex mutex_;
// key: provider


+ 2
- 2
mindspore/lite/src/registry/register_kernel.cc View File

@@ -23,7 +23,7 @@
namespace mindspore {
namespace registry {
Status RegisterKernel::RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider,
DataType data_type, const std::vector<char> &type, CreateKernel creator) {
DataType data_type, const std::vector<char> &type, const CreateKernel creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return RegistryKernelImpl::GetInstance()->RegCustomKernel(CharToString(arch), CharToString(provider), data_type,
CharToString(type), creator);
@@ -34,7 +34,7 @@ Status RegisterKernel::RegCustomKernel(const std::vector<char> &arch, const std:
}

Status RegisterKernel::RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
int op_type, CreateKernel creator) {
int op_type, const CreateKernel creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return RegistryKernelImpl::GetInstance()->RegKernel(CharToString(arch), CharToString(provider), data_type, op_type,
creator);


+ 52
- 25
mindspore/lite/src/registry/register_kernel_impl.cc View File

@@ -25,15 +25,14 @@ using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore::registry {
namespace {
static const auto kKernelMaxNum =
(static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1) *
(PrimitiveType_MAX - PrimitiveType_MIN);
static const auto kOpTypeLen = PrimitiveType_MAX - PrimitiveType_MIN + 1;
static const auto kDataTypeLen =
static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
static const auto kOpTypeLen = PrimitiveType_MAX - PrimitiveType_MIN;
} // namespace

int RegistryKernelImpl::GetFuncIndex(const KernelDesc &desc) {
static const auto kKernelMaxNum = kOpTypeLen * kDataTypeLen;
static constexpr auto kMaxProviderNum = 10;
static constexpr auto kMaxArchPerProviderNum = 10;
static constexpr auto kMaxCustomTypeNum = 200;
int GetFuncIndex(const KernelDesc &desc) {
if (desc.data_type >= DataType::kNumberTypeEnd) {
return -1;
}
@@ -47,14 +46,36 @@ int RegistryKernelImpl::GetFuncIndex(const KernelDesc &desc) {
}
return index;
}
} // namespace

Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) {
if (data_type >= DataType::kNumberTypeEnd) {
const std::string &type, const CreateKernel creator) {
int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
return kLiteError;
}
std::unique_lock<std::mutex> lock(lock_);
auto provider_iter = custom_kernel_creators_.find(provider);
if (provider_iter == custom_kernel_creators_.end() && custom_kernel_creators_.size() >= kMaxProviderNum) {
MS_LOG(ERROR) << "register too many provider!";
return kLiteError;
}
if (provider_iter != custom_kernel_creators_.end()) {
auto arch_iter = provider_iter->second.find(arch);
if (arch_iter == provider_iter->second.end()) {
if (provider_iter->second.size() >= kMaxArchPerProviderNum) {
MS_LOG(ERROR) << "register too many arch!";
return kLiteError;
}
} else {
auto type_iter = arch_iter->second.find(type);
if (type_iter == arch_iter->second.end() && arch_iter->second.size() >= kMaxCustomTypeNum) {
MS_LOG(ERROR) << "register too many type!";
return kLiteError;
}
}
}
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
custom_kernel_creators_[provider][arch][type] =
reinterpret_cast<CreateKernel *>(calloc(kDataTypeLen, sizeof(CreateKernel)));
@@ -64,20 +85,30 @@ Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::s
}
}

int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
return kLiteError;
}
custom_kernel_creators_[provider][arch][type][data_type_index] = creator;
return kSuccess;
}

Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
registry::CreateKernel creator) {
const registry::CreateKernel creator) {
if (type <= static_cast<int>(PrimitiveType_MIN) || type > static_cast<int>(PrimitiveType_MAX)) {
MS_LOG(ERROR) << "Invalid op type : " << type;
return kLiteParamInvalid;
}
KernelDesc desc = {data_type, type, arch, provider};
int index = GetFuncIndex(desc);
if (index < 0) {
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type "
<< type;
return kLiteError;
}
std::unique_lock<std::mutex> lock(lock_);
auto iter = kernel_creators_.find(provider);
if (iter == kernel_creators_.end()) {
if (kernel_creators_.size() >= kMaxProviderNum) {
MS_LOG(ERROR) << "register too many provider!";
return kLiteError;
}
kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
if (kernel_creators_[provider][arch] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
@@ -86,6 +117,10 @@ Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string
} else {
auto iter_arch = iter->second.find(arch);
if (iter_arch == iter->second.end()) {
if (iter->second.size() >= kMaxArchPerProviderNum) {
MS_LOG(ERROR) << "register too many arch!";
return kLiteError;
}
iter->second[arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
if (iter->second[arch] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
@@ -94,14 +129,6 @@ Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string
}
}

KernelDesc desc = {data_type, type, arch, provider};
int index = GetFuncIndex(desc);
if (index < 0) {
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type "
<< type;
return kLiteError;
}

kernel_creators_[provider][arch][index] = creator;
return kSuccess;
}
@@ -109,11 +136,11 @@ Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string
registry::CreateKernel RegistryKernelImpl::GetCustomKernelCreator(const schema::Primitive *primitive,
KernelDesc *desc) {
int data_type_index = static_cast<int>(desc->data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
if (data_type_index < 0 || desc->data_type >= DataType::kNumberTypeEnd) {
return nullptr;
}
auto param = primitive->value_as_Custom();
if (param == nullptr) {
if (param == nullptr || param->type() == nullptr) {
return nullptr;
}
auto custom_type = param->type()->str();


+ 2
- 3
mindspore/lite/src/registry/register_kernel_impl.h View File

@@ -37,10 +37,10 @@ class RegistryKernelImpl {
}

Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, registry::CreateKernel creator);
const std::string &type, const registry::CreateKernel creator);

Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
registry::CreateKernel creator);
const registry::CreateKernel creator);

virtual registry::CreateKernel GetProviderCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);

@@ -60,7 +60,6 @@ class RegistryKernelImpl {
std::mutex lock_;

registry::CreateKernel GetCustomKernelCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);
int GetFuncIndex(const registry::KernelDesc &desc);
};
} // namespace mindspore::registry



+ 7
- 5
mindspore/lite/src/registry/register_kernel_interface.cc View File

@@ -22,7 +22,8 @@

namespace mindspore {
namespace registry {
Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator) {
Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_type,
const KernelInterfaceCreator creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->Reg(CharToString(provider), op_type, creator);
#else
@@ -32,7 +33,7 @@ Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_ty
}

Status RegisterKernelInterface::CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
KernelInterfaceCreator creator) {
const KernelInterfaceCreator creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->CustomReg(CharToString(provider), CharToString(op_type), creator);
#else
@@ -41,10 +42,11 @@ Status RegisterKernelInterface::CustomReg(const std::vector<char> &provider, con
#endif
}

std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::vector<char> &provider, const schema::Primitive *primitive) {
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::vector<char> &provider,
const schema::Primitive *primitive,
const kernel::Kernel *kernel) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive);
return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive, kernel);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return nullptr;


+ 16
- 6
mindspore/lite/src/runtime/infer_manager.cc View File

@@ -34,23 +34,33 @@ namespace mindspore {
namespace lite {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const void *primitive, std::set<std::string> &&providers, int schema_version) {
if (primitive == nullptr) {
const void *primitive, std::set<std::string> &&providers, int schema_version,
const kernel::Kernel *kernel) {
if (primitive == nullptr && kernel == nullptr) {
return RET_NOT_SUPPORT;
}
std::shared_ptr<kernel::KernelInterface> kernel_interface = nullptr;
if (IsCustomNode(primitive, schema_version)) {
kernel_interface =
registry::RegisterKernelInterface::GetKernelInterface("", static_cast<const schema::Primitive *>(primitive));
bool is_custom_node = false;
if (kernel == nullptr) {
if (IsCustomNode(primitive, schema_version)) {
is_custom_node = true;
}
} else if (kernel->type() == schema::PrimitiveType_Custom) {
is_custom_node = true;
}
if (is_custom_node) {
kernel_interface = registry::RegisterKernelInterface::GetKernelInterface(
"", static_cast<const schema::Primitive *>(primitive), kernel);
} else {
for (auto &&provider : providers) {
kernel_interface = registry::RegisterKernelInterface::GetKernelInterface(
provider, static_cast<const schema::Primitive *>(primitive));
provider, static_cast<const schema::Primitive *>(primitive), kernel);
if (kernel_interface != nullptr) {
break;
}
}
}

if (kernel_interface == nullptr) {
return RET_NOT_SUPPORT;
}


+ 3
- 1
mindspore/lite/src/runtime/infer_manager.h View File

@@ -26,13 +26,15 @@
#include "src/tensor.h"
#include "nnacl/tensor_c.h"
#include "nnacl/infer/infer.h"
#include "include/api/kernel.h"

namespace mindspore::lite {
int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
OpParameter *parameter);
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const void *primitive, std::set<std::string> &&providers, int schema_version);
const void *primitive, std::set<std::string> &&providers, int schema_version,
const kernel::Kernel *kernel = nullptr);
#endif
class InferManager {
public:


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc View File

@@ -428,7 +428,7 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel)
MS_ASSERT(conv_kernel);
MS_ASSERT(scale_kernel);
auto *scale_param =
reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel)->GetParameter());
reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel->kernel())->GetParameter());
MS_ASSERT(scale_param);
MS_ASSERT(conv_kernel->in_tensors().size() >= INPUT_TENSOR_SIZE_2);
auto *filter = conv_kernel->in_tensors().at(1);


+ 3
- 0
mindspore/lite/src/scheduler.cc View File

@@ -1373,6 +1373,9 @@ kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src

SetKernelTensorDataType(kernel);
kernel->set_name(src_node->name_);
if (kernel->kernel() != nullptr) {
kernel->kernel()->SetConfig(config_info_);
}
return kernel;
}



+ 4
- 0
mindspore/lite/src/scheduler.h View File

@@ -59,6 +59,9 @@ class Scheduler {
~Scheduler() = default;
int Schedule(std::vector<kernel::LiteKernel *> *dst_kernels);
void SetupSchedulerCb(std::unique_ptr<SchedulerCb> cb) { sched_cb_ = std::move(cb); }
void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config_info) {
config_info_ = config_info;
}

private:
int SchedulePreProcess();
@@ -165,6 +168,7 @@ class Scheduler {
#endif
int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
std::map<std::string, TypeId> *execution_plan_ = nullptr;
const std::map<std::string, std::map<std::string, std::string>> *config_info_ = nullptr;
};
} // namespace mindspore::lite



+ 1
- 1
mindspore/lite/src/sub_graph_kernel.cc View File

@@ -92,7 +92,7 @@ int SubGraphKernel::ReSize() {
int ret;
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
ret = lite::KernelInferShape(inputs, outputs, kernel->kernel()->primitive(), kernel->Context()->GetProviders(),
schema_version_);
schema_version_, kernel->kernel());
if (ret == lite::RET_NOT_SUPPORT) {
#endif
auto parameter = kernel->op_parameter();


+ 3
- 3
mindspore/lite/test/st/mix_data_type_test.cc View File

@@ -51,10 +51,10 @@ TEST_F(MixDataTypeTest, Config1) {

std::string filename = "MixDataTypeTestConfig";
std::string sectionname = "execution_plan";
std::map<std::string, std::string> config_info;
ret = lite::GetSectionInfoFromConfigFile(filename, sectionname, &config_info);
std::map<std::string, std::map<std::string, std::string>> configs;
ret = lite::GetAllSectionInfoFromConfigFile(filename, &configs);
ASSERT_EQ(ret, 0);
std::map<std::string, std::string> config_info = configs[sectionname];
ASSERT_EQ(config_info.size(), 2);

auto info0 = config_info.at("op1");


Loading…
Cancel
Save