| @@ -19,13 +19,14 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <map> | |||
| #include "schema/model_generated.h" | |||
| #include "include/api/types.h" | |||
| #include "include/api/context.h" | |||
| namespace mindspore::kernel { | |||
| /// \brief The Kernel class is used to define a MindSpore Kernel. | |||
| class Kernel { | |||
| class MS_API Kernel { | |||
| public: | |||
| Kernel() = default; | |||
| /// \brief Constructor. | |||
| @@ -37,9 +38,7 @@ class Kernel { | |||
| Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs, | |||
| const schema::Primitive *primitive, const mindspore::Context *ctx) | |||
| : context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) { | |||
| if (primitive != nullptr) { | |||
| type_ = primitive->value_type(); | |||
| } | |||
| Initialize(); | |||
| } | |||
| /// \brief Destructor. | |||
| virtual ~Kernel() = default; | |||
| @@ -102,6 +101,44 @@ class Kernel { | |||
| /// \return the primitive of kernel generated by flatbuffers. | |||
| const schema::Primitive *primitive() const { return this->primitive_; } | |||
| /// \brief get kernel's attribute. | |||
| /// | |||
| /// \param[in] key define the kernel's attribute key. | |||
| std::string GetAttr(const std::string &key) const { | |||
| auto iter = attrs_.find(key); | |||
| if (iter != attrs_.end()) { | |||
| return iter->second; | |||
| } | |||
| return ""; | |||
| } | |||
| /// \brief set kernel's config. | |||
| /// | |||
| /// \param[in] config define the kernel's config. | |||
| void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config) { | |||
| config_ = config; | |||
| } | |||
| /// \brief set kernel's config. | |||
| /// | |||
| /// \param[in] config define the kernel's config. | |||
| std::map<std::string, std::string> GetConfig(const std::string §ion) const { | |||
| if (config_ == nullptr) { | |||
| return std::map<std::string, std::string>(); | |||
| } | |||
| auto iter = config_->find(section); | |||
| if (iter != config_->end()) { | |||
| return iter->second; | |||
| } | |||
| return std::map<std::string, std::string>(); | |||
| } | |||
| protected: | |||
| /// \brief set kernel's attribute | |||
| /// | |||
| /// \param[in] key define the kernel's attribute key. | |||
| /// \param[in] value define the kernel's attribute value. | |||
| void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; } | |||
| protected: | |||
| std::string name_; | |||
| const mindspore::Context *context_ = nullptr; | |||
| @@ -109,6 +146,11 @@ class Kernel { | |||
| std::vector<mindspore::MSTensor> outputs_; | |||
| schema::PrimitiveType type_ = schema::PrimitiveType_NONE; | |||
| const schema::Primitive *primitive_ = nullptr; | |||
| std::map<std::string, std::string> attrs_; | |||
| const std::map<std::string, std::map<std::string, std::string>> *config_; | |||
| private: | |||
| void Initialize(); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -106,6 +106,14 @@ class MS_API Model { | |||
| /// \return Status. | |||
| inline Status LoadConfig(const std::string &config_path); | |||
| /// \brief Update config. | |||
| /// | |||
| /// \param[in] section define the config section. | |||
| /// \param[in] config define the config will be updated. | |||
| /// | |||
| /// \return Status. | |||
| inline Status UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config); | |||
| /// \brief Obtains all input tensors of the model. | |||
| /// | |||
| /// \return The vector that includes all input tensors. | |||
| @@ -215,6 +223,7 @@ class MS_API Model { | |||
| MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name); | |||
| std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name); | |||
| Status LoadConfig(const std::vector<char> &config_path); | |||
| Status UpdateConfig(const std::vector<char> §ion, const std::pair<std::vector<char>, std::vector<char>> &config); | |||
| Status Build(const void *model_data, size_t data_size, ModelType model_type, | |||
| const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode); | |||
| Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context, | |||
| @@ -241,6 +250,12 @@ Status Model::LoadConfig(const std::string &config_path) { | |||
| return LoadConfig(StringToChar(config_path)); | |||
| } | |||
| Status Model::UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config) { | |||
| std::pair<std::vector<char>, std::vector<char>> config_pair = {StringToChar(config.first), | |||
| StringToChar(config.second)}; | |||
| return UpdateConfig(StringToChar(section), config_pair); | |||
| } | |||
| Status Model::Build(const void *model_data, size_t data_size, ModelType model_type, | |||
| const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) { | |||
| return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode)); | |||
| @@ -71,7 +71,7 @@ class MS_API RegisterKernel { | |||
| /// | |||
| /// \return Status as a status identification of registering. | |||
| inline static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, | |||
| CreateKernel creator); | |||
| const CreateKernel creator); | |||
| /// \brief Static method to register kernel which is corresponding to custom op. | |||
| /// | |||
| @@ -83,7 +83,7 @@ class MS_API RegisterKernel { | |||
| /// | |||
| /// \return Status as a status identification of registering. | |||
| inline static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, | |||
| const std::string &type, CreateKernel creator); | |||
| const std::string &type, const CreateKernel creator); | |||
| /// \brief Static methon to get a kernel's create function. | |||
| /// | |||
| @@ -95,9 +95,9 @@ class MS_API RegisterKernel { | |||
| private: | |||
| static Status RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type, | |||
| int type, CreateKernel creator); | |||
| int type, const CreateKernel creator); | |||
| static Status RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type, | |||
| const std::vector<char> &type, CreateKernel creator); | |||
| const std::vector<char> &type, const CreateKernel creator); | |||
| static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc); | |||
| }; | |||
| @@ -115,7 +115,7 @@ class MS_API KernelReg { | |||
| /// \param[in] op_type Define the ordinary op type. | |||
| /// \param[in] creator Define a function pointer to create a kernel. | |||
| KernelReg(const std::string &arch, const std::string &provider, DataType data_type, int op_type, | |||
| CreateKernel creator) { | |||
| const CreateKernel creator) { | |||
| RegisterKernel::RegKernel(arch, provider, data_type, op_type, creator); | |||
| } | |||
| @@ -127,18 +127,18 @@ class MS_API KernelReg { | |||
| /// \param[in] op_type Define the concrete type of a custom op. | |||
| /// \param[in] creator Define a function pointer to create a kernel. | |||
| KernelReg(const std::string &arch, const std::string &provider, DataType data_type, const std::string &op_type, | |||
| CreateKernel creator) { | |||
| const CreateKernel creator) { | |||
| RegisterKernel::RegCustomKernel(arch, provider, data_type, op_type, creator); | |||
| } | |||
| }; | |||
| Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, | |||
| CreateKernel creator) { | |||
| const CreateKernel creator) { | |||
| return RegKernel(StringToChar(arch), StringToChar(provider), data_type, type, creator); | |||
| } | |||
| Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, | |||
| const std::string &type, CreateKernel creator) { | |||
| const std::string &type, const CreateKernel creator) { | |||
| return RegCustomKernel(StringToChar(arch), StringToChar(provider), data_type, StringToChar(type), creator); | |||
| } | |||
| @@ -25,6 +25,9 @@ | |||
| #include "schema/model_generated.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class Kernel; | |||
| } | |||
| namespace registry { | |||
| /// \brief KernelInterfaceCreator defined a functor to create KernelInterface. | |||
| using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>; | |||
| @@ -40,7 +43,7 @@ class MS_API RegisterKernelInterface { | |||
| /// | |||
| /// \return Status as a status identification of registering. | |||
| inline static Status CustomReg(const std::string &provider, const std::string &op_type, | |||
| KernelInterfaceCreator creator); | |||
| const KernelInterfaceCreator creator); | |||
| /// \brief Static method to register op whose primitive type is ordinary. | |||
| /// | |||
| @@ -49,23 +52,26 @@ class MS_API RegisterKernelInterface { | |||
| /// \param[in] creator Define the KernelInterface create function. | |||
| /// | |||
| /// \return Status as a status identification of registering. | |||
| inline static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator); | |||
| inline static Status Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator); | |||
| /// \brief Static method to get registration of a certain op. | |||
| /// | |||
| /// \param[in] provider Define the identification of user. | |||
| /// \param[in] primitive Define the attributes of a certain op. | |||
| /// \param[in] kernel Define the kernel of a certain op. | |||
| /// | |||
| /// \return Boolean value to represent registration of a certain op is existing or not. | |||
| inline static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider, | |||
| const schema::Primitive *primitive); | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel = nullptr); | |||
| private: | |||
| static Status CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type, | |||
| KernelInterfaceCreator creator); | |||
| static Status Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator); | |||
| const KernelInterfaceCreator creator); | |||
| static Status Reg(const std::vector<char> &provider, int op_type, const KernelInterfaceCreator creator); | |||
| static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::vector<char> &provider, | |||
| const schema::Primitive *primitive); | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel = nullptr); | |||
| }; | |||
| /// \brief KernelInterfaceReg defined registration class of KernelInterface. | |||
| @@ -76,7 +82,7 @@ class MS_API KernelInterfaceReg { | |||
| /// \param[in] provider Define the identification of user. | |||
| /// \param[in] op_type Define the ordinary op type. | |||
| /// \param[in] creator Define the KernelInterface create function. | |||
| KernelInterfaceReg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { | |||
| KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) { | |||
| RegisterKernelInterface::Reg(provider, op_type, creator); | |||
| } | |||
| @@ -85,23 +91,26 @@ class MS_API KernelInterfaceReg { | |||
| /// \param[in] provider Define the identification of user. | |||
| /// \param[in] op_type Define the concrete type of a custom op. | |||
| /// \param[in] creator Define the KernelInterface create function. | |||
| KernelInterfaceReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator) { | |||
| KernelInterfaceReg(const std::string &provider, const std::string &op_type, const KernelInterfaceCreator creator) { | |||
| RegisterKernelInterface::CustomReg(provider, op_type, creator); | |||
| } | |||
| virtual ~KernelInterfaceReg() = default; | |||
| }; | |||
| Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type, | |||
| KernelInterfaceCreator creator) { | |||
| const KernelInterfaceCreator creator) { | |||
| return CustomReg(StringToChar(provider), StringToChar(op_type), creator); | |||
| } | |||
| Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { | |||
| Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) { | |||
| return Reg(StringToChar(provider), op_type, creator); | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface( | |||
| const std::string &provider, const schema::Primitive *primitive) { | |||
| return GetKernelInterface(StringToChar(provider), primitive); | |||
| std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::string &provider, | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel) { | |||
| return GetKernelInterface(StringToChar(provider), primitive, kernel); | |||
| } | |||
| /// \brief Defined registering macro to register ordinary op, which called by user directly. | |||
| @@ -21,19 +21,55 @@ | |||
| #endif | |||
| namespace { | |||
| constexpr size_t kLengthOfParentheses = 2; | |||
| } | |||
| constexpr size_t kMinSectionLineLength = 2; | |||
| constexpr size_t kMaxValidLineCount = 100000; | |||
| constexpr size_t kMaxLineCount = 100100; | |||
| } // namespace | |||
| namespace mindspore { | |||
| namespace lite { | |||
| int GetSectionInfoFromConfigFile(const std::string &file, const std::string §ion_name, | |||
| std::map<std::string, std::string> *section_info) { | |||
| if (file.empty()) { | |||
| MS_LOG(ERROR) << "file is nullptr"; | |||
| namespace { | |||
| void ParseLine(const std::string &line, std::map<std::string, std::string> *section_config, std::string *section, | |||
| size_t *valid_line_count, std::map<std::string, std::map<std::string, std::string>> *config) { | |||
| // eg: [section] | |||
| // key=value | |||
| if (line[0] == '[' && line[line.length() - 1] == ']') { | |||
| if (!section->empty() && !section_config->empty()) { | |||
| config->insert(std::make_pair(*section, *section_config)); | |||
| } | |||
| section_config->clear(); | |||
| *section = line.substr(1, line.length() - kLengthOfParentheses); | |||
| *valid_line_count = *valid_line_count + 1; | |||
| } | |||
| if (!section->empty()) { | |||
| auto index = line.find('='); | |||
| if (index == std::string::npos) { | |||
| return; | |||
| } | |||
| auto key = line.substr(0, index); | |||
| if (index + 1 > line.size()) { | |||
| return; | |||
| } | |||
| auto value = line.substr(index + 1); | |||
| lite::Trim(&key); | |||
| lite::Trim(&value); | |||
| section_config->insert(std::make_pair(key, value)); | |||
| *valid_line_count = *valid_line_count + 1; | |||
| } | |||
| } | |||
| } // namespace | |||
| int GetAllSectionInfoFromConfigFile(const std::string &file, | |||
| std::map<std::string, std::map<std::string, std::string>> *config) { | |||
| if (file.empty() || config == nullptr) { | |||
| MS_LOG(ERROR) << "input Invalid!check file and config."; | |||
| return RET_ERROR; | |||
| } | |||
| auto resolved_path = std::make_unique<char[]>(PATH_MAX); | |||
| if (resolved_path == nullptr) { | |||
| MS_LOG(ERROR) << "new resolved_path failed"; | |||
| MS_LOG(ERROR) << "new resolved_path fail!"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -56,44 +92,25 @@ int GetSectionInfoFromConfigFile(const std::string &file, const std::string &sec | |||
| return RET_ERROR; | |||
| } | |||
| std::string line; | |||
| bool find_section = false; | |||
| std::string section; | |||
| std::map<std::string, std::string> section_config; | |||
| size_t line_count = 0; | |||
| size_t valid_line_count = 0; | |||
| while (std::getline(ifs, line)) { | |||
| lite::Trim(&line); | |||
| if (line.empty()) { | |||
| continue; | |||
| line_count++; | |||
| if (line_count >= kMaxLineCount || valid_line_count >= kMaxValidLineCount) { | |||
| MS_LOG(ERROR) << "config too many lines!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (line[0] == '#') { | |||
| lite::Trim(&line); | |||
| if (line.length() <= kMinSectionLineLength || line[0] == '#') { | |||
| continue; | |||
| } | |||
| if (line[0] == '[') { | |||
| if (find_section == true) { | |||
| break; | |||
| } | |||
| std::string section = line.substr(1, line.length() - kLengthOfParentheses); | |||
| if (section != section_name) { | |||
| continue; | |||
| } | |||
| find_section = true; | |||
| } | |||
| if (find_section == true) { | |||
| auto index = line.find('='); | |||
| if (index == std::string::npos) { | |||
| continue; | |||
| } | |||
| auto key = line.substr(0, index); | |||
| if (index + 1 > line.size()) { | |||
| return RET_ERROR; | |||
| } | |||
| auto value = line.substr(index + 1); | |||
| lite::Trim(&key); | |||
| lite::Trim(&value); | |||
| section_info->insert(std::make_pair(key, value)); | |||
| } | |||
| ParseLine(line, §ion_config, §ion, &valid_line_count, config); | |||
| } | |||
| if (!section.empty() && !section_config.empty()) { | |||
| config->insert(std::make_pair(section, section_config)); | |||
| } | |||
| ifs.close(); | |||
| return RET_OK; | |||
| } | |||
| @@ -35,10 +35,8 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr int MAX_CONFIG_FILE_LENGTH = 1024; | |||
| #define CONFIG_FILE_EXECUTION_PLAN "execution_plan" | |||
| int GetSectionInfoFromConfigFile(const std::string &file, const std::string §ion_name, | |||
| std::map<std::string, std::string> *section_info); | |||
| int GetAllSectionInfoFromConfigFile(const std::string &file, | |||
| std::map<std::string, std::map<std::string, std::string>> *config); | |||
| void ParserExecutionPlan(const std::map<std::string, std::string> *config_infos, | |||
| std::map<std::string, TypeId> *data_type_plan); | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/api/kernel.h" | |||
| namespace mindspore::kernel { | |||
| void Kernel::Initialize() { | |||
| if (primitive_ == nullptr) { | |||
| return; | |||
| } | |||
| type_ = primitive_->value_type(); | |||
| if (type_ == schema::PrimitiveType_Custom) { | |||
| auto param = primitive_->value_as_Custom(); | |||
| if (param != nullptr && param->type() != nullptr) { | |||
| SetAttr("type", param->type()->str()); | |||
| } | |||
| } | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -209,6 +209,19 @@ Status Model::LoadConfig(const std::vector<char> &config_path) { | |||
| return kSuccess; | |||
| } | |||
| Status Model::UpdateConfig(const std::vector<char> §ion, | |||
| const std::pair<std::vector<char>, std::vector<char>> &config) { | |||
| std::unique_lock<std::mutex> impl_lock(g_impl_init_lock); | |||
| if (impl_ == nullptr) { | |||
| impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl()); | |||
| } | |||
| if (impl_ != nullptr) { | |||
| return impl_->UpdateConfig(CharToString(section), {CharToString(config.first), CharToString(config.second)}); | |||
| } | |||
| MS_LOG(ERROR) << "Model implement is null!"; | |||
| return kLiteFileError; | |||
| } | |||
| Status Model::SetTrainMode(bool train) { | |||
| if ((impl_ == nullptr) || (impl_->session_ == nullptr)) { | |||
| MS_LOG(ERROR) << "Model is null."; | |||
| @@ -17,6 +17,10 @@ | |||
| #include "src/cxx_api/model/model_impl.h" | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "include/api/types.h" | |||
| #include "include/api/context.h" | |||
| #include "include/lite_session.h" | |||
| @@ -32,6 +36,11 @@ | |||
| #include "src/common/config_file.h" | |||
| namespace mindspore { | |||
| namespace { | |||
| static const char *kExecutionPlan = "execution_plan"; | |||
| static constexpr size_t kMaxSectionNum = 100; | |||
| static constexpr size_t kMaxConfigNumPerSection = 1000; | |||
| } // namespace | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| @@ -195,15 +204,16 @@ Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBac | |||
| bool ModelImpl::IsTrainModel() { return (graph_ && graph_->graph_data_ && graph_->graph_data_->IsTrainModel()); } | |||
| Status ModelImpl::LoadConfig(const std::string &config_path) { | |||
| std::map<std::string, std::string> config_info; | |||
| int ret = lite::GetSectionInfoFromConfigFile(config_path, CONFIG_FILE_EXECUTION_PLAN, &config_info); | |||
| std::map<std::string, std::map<std::string, std::string>> all_config_info; | |||
| int ret = lite::GetAllSectionInfoFromConfigFile(config_path, &all_config_info); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "GetSectionInfoFromConfigFile failed."; | |||
| MS_LOG(ERROR) << "GetAllSectionInfoFromConfigFile fail!ret: " << ret; | |||
| return kLiteFileError; | |||
| } | |||
| config_info_ = all_config_info; | |||
| std::map<std::string, std::string> config_info = all_config_info[kExecutionPlan]; | |||
| if (config_info.empty()) { | |||
| MS_LOG(WARNING) << "No valid info in config file."; | |||
| MS_LOG(WARNING) << "No valid execution plan info in config file."; | |||
| return kSuccess; | |||
| } | |||
| @@ -211,6 +221,24 @@ Status ModelImpl::LoadConfig(const std::string &config_path) { | |||
| return kSuccess; | |||
| } | |||
| Status ModelImpl::UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config) { | |||
| auto iter = config_info_.find(section); | |||
| if (iter == config_info_.end()) { | |||
| if (config_info_.size() >= kMaxSectionNum) { | |||
| MS_LOG(ERROR) << "config too many sections!"; | |||
| return kLiteError; | |||
| } | |||
| config_info_[section][config.first] = config.second; | |||
| return kSuccess; | |||
| } | |||
| if (iter->second.size() >= kMaxConfigNumPerSection) { | |||
| MS_LOG(ERROR) << "config too many items!"; | |||
| return kLiteError; | |||
| } | |||
| iter->second[config.first] = config.second; | |||
| return kSuccess; | |||
| } | |||
| Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, | |||
| const MSKernelCallBack &before, const MSKernelCallBack &after) { | |||
| if (outputs == nullptr) { | |||
| @@ -567,6 +595,7 @@ session::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) | |||
| } | |||
| session->InitExecutionConfig(&execution_plan_); | |||
| session->SetConfigInfo(&config_info_); | |||
| auto ret = session->Init(context); | |||
| if (ret != mindspore::lite::RET_OK) { | |||
| @@ -70,6 +70,7 @@ class ModelImpl { | |||
| session::LiteSession *CreateLiteSession(lite::InnerContext *context); | |||
| Status LoadConfig(const std::string &config_path); | |||
| Status UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config); | |||
| std::vector<MSTensor> GetInputs(); | |||
| std::vector<MSTensor> GetOutputs(); | |||
| std::vector<MSTensor> GetGradients() const; | |||
| @@ -112,6 +113,7 @@ class ModelImpl { | |||
| void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; } | |||
| Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after); | |||
| std::map<std::string, TypeId> execution_plan_; | |||
| std::map<std::string, std::map<std::string, std::string>> config_info_; | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -523,6 +523,7 @@ int LiteSession::CompileGraph(Model *model) { | |||
| Scheduler scheduler(context_, ms_context_, model, &tensors_, inputs_, outputs_, is_train_session_, &is_infershape_, | |||
| &is_control_flow_, execution_plan_, delegate_, delegate_device_type_); | |||
| scheduler.SetupSchedulerCb(std::move(sched_cb_)); | |||
| scheduler.SetConfig(config_info_); | |||
| ret = scheduler.Schedule(&kernels_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Schedule kernels failed: " << ret; | |||
| @@ -87,6 +87,10 @@ class LiteSession : public session::LiteSession { | |||
| const Delegate *get_delegate() const { return this->delegate_.get(); } | |||
| void SetConfigInfo(const std::map<std::string, std::map<std::string, std::string>> *config_info) { | |||
| config_info_ = config_info; | |||
| } | |||
| protected: | |||
| static void ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lite::Tensor *dst_tensor); | |||
| @@ -182,6 +186,7 @@ class LiteSession : public session::LiteSession { | |||
| std::shared_ptr<Delegate> delegate_ = nullptr; | |||
| int delegate_device_type_ = -1; // -1: not specified; 0: CPU; 1: GPU; 2: NPU | |||
| std::map<std::string, TypeId> *execution_plan_ = nullptr; | |||
| const std::map<std::string, std::map<std::string, std::string>> *config_info_ = nullptr; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -70,14 +70,14 @@ OpParameter *PopulateAffineParameter(const void *prim) { | |||
| affine_param->context_size_ = static_cast<int>(context.size()); | |||
| // malloc && memset for context | |||
| affine_param->context_ = reinterpret_cast<int *>(malloc(affine_param->context_size_ * sizeof(int))); | |||
| affine_param->context_ = reinterpret_cast<int *>(malloc(context.size() * sizeof(int))); | |||
| if (affine_param->context_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc param context_ for affine layer failed!"; | |||
| ReleaseParam(affine_param, matmul_param); | |||
| return nullptr; | |||
| } | |||
| memset(affine_param->context_, 0, affine_param->context_size_ * sizeof(int)); | |||
| for (int i = 0; i < affine_param->context_size_; ++i) { | |||
| (void)memset(affine_param->context_, 0, context.size() * sizeof(int)); | |||
| for (size_t i = 0; i < context.size(); ++i) { | |||
| affine_param->context_[i] = context.at(i); | |||
| } | |||
| affine_param->output_dim_ = value->output_dim(); | |||
| @@ -43,8 +43,8 @@ OpParameter *PopulateTensorArrayParameter(const void *prim) { | |||
| bool identical_element_shapes = value->identical_element_shapes(); | |||
| param->identical_element_shapes_ = identical_element_shapes; | |||
| std::vector<int> primitive_element_shape(value->element_shape()->begin(), value->element_shape()->end()); | |||
| param->element_shape_size_ = primitive_element_shape.size(); | |||
| int size = sizeof(int) * param->element_shape_size_; | |||
| param->element_shape_size_ = static_cast<int>(primitive_element_shape.size()); | |||
| auto size = sizeof(int) * param->element_shape_size_; | |||
| param->element_shape_ = static_cast<int *>(malloc(size)); | |||
| if (param->element_shape_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc element_shape failed!"; | |||
| @@ -52,7 +52,7 @@ OpParameter *PopulateSpliceParameter(const void *prim) { | |||
| param->context_dim_ = static_cast<int>(primitive_context.size()); | |||
| // malloc && memset for context | |||
| param->context_ = reinterpret_cast<int *>(malloc(param->context_dim_ * sizeof(int))); | |||
| param->context_ = reinterpret_cast<int *>(malloc(primitive_context.size() * sizeof(int))); | |||
| if (param->context_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc param context_ error"; | |||
| free(param); | |||
| @@ -60,8 +60,8 @@ OpParameter *PopulateSpliceParameter(const void *prim) { | |||
| } | |||
| // src_to_dst_row_offset | |||
| int src_to_dst_row_offset = INT32_MIN; | |||
| memset(param->context_, 0, param->context_dim_ * sizeof(int)); | |||
| for (int i = 0; i < param->context_dim_; ++i) { | |||
| (void)memset(param->context_, 0, primitive_context.size() * sizeof(int)); | |||
| for (size_t i = 0; i < primitive_context.size(); ++i) { | |||
| param->context_[i] = primitive_context[i]; | |||
| src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i))); | |||
| } | |||
| @@ -83,15 +83,15 @@ OpParameter *PopulateSpliceParameter(const void *prim) { | |||
| param->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size()); | |||
| // malloc && memset for forward_indexes | |||
| param->forward_indexes_ = reinterpret_cast<int *>(malloc(param->forward_indexes_dim_ * sizeof(int))); | |||
| param->forward_indexes_ = reinterpret_cast<int *>(malloc(primitive_context.size() * sizeof(int))); | |||
| if (param->forward_indexes_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc param forward_indexes_ error"; | |||
| free(param->context_); | |||
| free(param); | |||
| return nullptr; | |||
| } | |||
| memset(param->forward_indexes_, 0, param->forward_indexes_dim_ * sizeof(int)); | |||
| memcpy(param->forward_indexes_, primitive_forward_indexes.data(), param->forward_indexes_dim_ * sizeof(int)); | |||
| (void)memset(param->forward_indexes_, 0, primitive_context.size() * sizeof(int)); | |||
| (void)memcpy(param->forward_indexes_, primitive_forward_indexes.data(), primitive_context.size() * sizeof(int)); | |||
| param->output_dim_ = value->output_dim(); | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/version_manager.h" | |||
| #include "schema/model_generated.h" | |||
| #include "include/api/kernel.h" | |||
| using mindspore::registry::KernelInterfaceCreator; | |||
| using mindspore::schema::PrimitiveType_MAX; | |||
| @@ -27,16 +28,33 @@ using mindspore::schema::PrimitiveType_MIN; | |||
| namespace mindspore { | |||
| namespace registry { | |||
| namespace { | |||
| static constexpr auto kMaxProviderNum = 10; | |||
| static constexpr auto KMaxCustomTypeNum = 200; | |||
| static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1; | |||
| std::string GetCustomType(const schema::Primitive *primitive) { | |||
| auto param = primitive->value_as_Custom(); | |||
| MS_ASSERT(param != nullptr); | |||
| if (param == nullptr || param->type() == nullptr) { | |||
| return ""; | |||
| } | |||
| return param->type()->str(); | |||
| } | |||
| } // namespace | |||
| Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type, | |||
| KernelInterfaceCreator creator) { | |||
| const KernelInterfaceCreator creator) { | |||
| auto provider_iter = custom_creators_.find(provider); | |||
| if (provider_iter == custom_creators_.end() && custom_creators_.size() >= kMaxProviderNum) { | |||
| MS_LOG(ERROR) << "register too many provider!"; | |||
| return kLiteError; | |||
| } | |||
| if (provider_iter != custom_creators_.end()) { | |||
| auto type_iter = provider_iter->second.find(type); | |||
| if (type_iter == provider_iter->second.end() && provider_iter->second.size() >= KMaxCustomTypeNum) { | |||
| MS_LOG(ERROR) << "register too many custom type!"; | |||
| return kLiteError; | |||
| } | |||
| } | |||
| custom_creators_[provider][type] = creator; | |||
| return kSuccess; | |||
| } | |||
| @@ -73,15 +91,19 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCache | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface( | |||
| const schema::Primitive *primitive) { | |||
| MS_ASSERT(primitive != nullptr); | |||
| const schema::Primitive *primitive, const kernel::Kernel *kernel) { | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| auto &&type = GetCustomType(primitive); | |||
| std::string type; | |||
| if (kernel == nullptr) { | |||
| type = GetCustomType(primitive); | |||
| } else { | |||
| type = kernel->GetAttr("type"); | |||
| } | |||
| for (auto &&item : custom_creators_) { | |||
| auto &&provider = item.first; | |||
| auto kernel = GetCustomCacheInterface(provider, type); | |||
| if (kernel != nullptr) { | |||
| return kernel; | |||
| auto kernel_interface = GetCustomCacheInterface(provider, type); | |||
| if (kernel_interface != nullptr) { | |||
| return kernel_interface; | |||
| } | |||
| auto provider_iter = custom_creators_.find(provider); | |||
| if (provider_iter == custom_creators_.end()) { | |||
| @@ -89,47 +111,54 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKerne | |||
| } | |||
| auto creator_iter = provider_iter->second.find(type); | |||
| if (creator_iter != provider_iter->second.end()) { | |||
| kernel = creator_iter->second(); | |||
| custom_kernels_[provider][type] = kernel; | |||
| return kernel; | |||
| kernel_interface = creator_iter->second(); | |||
| custom_kernels_[provider][type] = kernel_interface; | |||
| return kernel_interface; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface( | |||
| const std::string &provider, const schema::Primitive *primitive) { | |||
| if (primitive == nullptr) { | |||
| std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(const std::string &provider, | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel) { | |||
| if (primitive == nullptr && kernel == nullptr) { | |||
| return nullptr; | |||
| } | |||
| int op_type; | |||
| if (kernel == nullptr) { | |||
| op_type = static_cast<int>(primitive->value_type()); | |||
| } else { | |||
| op_type = static_cast<int>(kernel->type()); | |||
| } | |||
| if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) { | |||
| return nullptr; | |||
| } | |||
| int op_type = primitive->value_type(); | |||
| if (op_type == schema::PrimitiveType_Custom) { | |||
| return GetCustomKernelInterface(primitive); | |||
| return GetCustomKernelInterface(primitive, kernel); | |||
| } | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| auto kernel = GetCacheInterface(provider, op_type); | |||
| if (kernel != nullptr) { | |||
| return kernel; | |||
| auto kernel_interface = GetCacheInterface(provider, op_type); | |||
| if (kernel_interface != nullptr) { | |||
| return kernel_interface; | |||
| } | |||
| auto iter = kernel_creators_.find(provider); | |||
| if (iter == kernel_creators_.end()) { | |||
| return nullptr; | |||
| } | |||
| if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) { | |||
| return nullptr; | |||
| } | |||
| auto creator = iter->second[op_type]; | |||
| if (creator != nullptr) { | |||
| kernel = creator(); | |||
| kernel_interfaces_[provider][op_type] = kernel; | |||
| return kernel; | |||
| kernel_interface = creator(); | |||
| kernel_interfaces_[provider][op_type] = kernel_interface; | |||
| return kernel_interface; | |||
| } | |||
| return nullptr; | |||
| } | |||
| Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { | |||
| Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) { | |||
| if (op_type <= PrimitiveType_MIN || op_type > PrimitiveType_MAX) { | |||
| MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << PrimitiveType_MAX; | |||
| return kLiteParamInvalid; | |||
| @@ -142,6 +171,10 @@ Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, Ke | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| auto iter = kernel_creators_.find(provider); | |||
| if (iter == kernel_creators_.end()) { | |||
| if (kernel_creators_.size() >= kMaxProviderNum) { | |||
| MS_LOG(ERROR) << "register too many provider!"; | |||
| return kLiteError; | |||
| } | |||
| kernel_creators_[provider] = | |||
| reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator))); | |||
| if (kernel_creators_[provider] == nullptr) { | |||
| @@ -35,9 +35,11 @@ class KernelInterfaceRegistry { | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider, | |||
| const schema::Primitive *primitive); | |||
| Status CustomReg(const std::string &provider, const std::string &op_type, registry::KernelInterfaceCreator creator); | |||
| Status Reg(const std::string &provider, int op_type, registry::KernelInterfaceCreator creator); | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel); | |||
| Status CustomReg(const std::string &provider, const std::string &op_type, | |||
| const registry::KernelInterfaceCreator creator); | |||
| Status Reg(const std::string &provider, int op_type, const registry::KernelInterfaceCreator creator); | |||
| virtual ~KernelInterfaceRegistry(); | |||
| private: | |||
| @@ -45,7 +47,8 @@ class KernelInterfaceRegistry { | |||
| std::shared_ptr<kernel::KernelInterface> GetCacheInterface(const std::string &provider, int op_type); | |||
| std::shared_ptr<kernel::KernelInterface> GetCustomCacheInterface(const std::string &provider, | |||
| const std::string &type); | |||
| std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive); | |||
| std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel); | |||
| std::mutex mutex_; | |||
| // key: provider | |||
| @@ -23,7 +23,7 @@ | |||
| namespace mindspore { | |||
| namespace registry { | |||
| Status RegisterKernel::RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider, | |||
| DataType data_type, const std::vector<char> &type, CreateKernel creator) { | |||
| DataType data_type, const std::vector<char> &type, const CreateKernel creator) { | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| return RegistryKernelImpl::GetInstance()->RegCustomKernel(CharToString(arch), CharToString(provider), data_type, | |||
| CharToString(type), creator); | |||
| @@ -34,7 +34,7 @@ Status RegisterKernel::RegCustomKernel(const std::vector<char> &arch, const std: | |||
| } | |||
| Status RegisterKernel::RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type, | |||
| int op_type, CreateKernel creator) { | |||
| int op_type, const CreateKernel creator) { | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| return RegistryKernelImpl::GetInstance()->RegKernel(CharToString(arch), CharToString(provider), data_type, op_type, | |||
| creator); | |||
| @@ -25,15 +25,14 @@ using mindspore::schema::PrimitiveType_MAX; | |||
| using mindspore::schema::PrimitiveType_MIN; | |||
| namespace mindspore::registry { | |||
| namespace { | |||
| static const auto kKernelMaxNum = | |||
| (static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1) * | |||
| (PrimitiveType_MAX - PrimitiveType_MIN); | |||
| static const auto kOpTypeLen = PrimitiveType_MAX - PrimitiveType_MIN + 1; | |||
| static const auto kDataTypeLen = | |||
| static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1; | |||
| static const auto kOpTypeLen = PrimitiveType_MAX - PrimitiveType_MIN; | |||
| } // namespace | |||
| int RegistryKernelImpl::GetFuncIndex(const KernelDesc &desc) { | |||
| static const auto kKernelMaxNum = kOpTypeLen * kDataTypeLen; | |||
| static constexpr auto kMaxProviderNum = 10; | |||
| static constexpr auto kMaxArchPerProviderNum = 10; | |||
| static constexpr auto kMaxCustomTypeNum = 200; | |||
| int GetFuncIndex(const KernelDesc &desc) { | |||
| if (desc.data_type >= DataType::kNumberTypeEnd) { | |||
| return -1; | |||
| } | |||
| @@ -47,14 +46,36 @@ int RegistryKernelImpl::GetFuncIndex(const KernelDesc &desc) { | |||
| } | |||
| return index; | |||
| } | |||
| } // namespace | |||
| Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, | |||
| const std::string &type, CreateKernel creator) { | |||
| if (data_type >= DataType::kNumberTypeEnd) { | |||
| const std::string &type, const CreateKernel creator) { | |||
| int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1; | |||
| if (data_type_index < 0 || data_type_index >= kDataTypeLen) { | |||
| MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider; | |||
| return kLiteError; | |||
| } | |||
| std::unique_lock<std::mutex> lock(lock_); | |||
| auto provider_iter = custom_kernel_creators_.find(provider); | |||
| if (provider_iter == custom_kernel_creators_.end() && custom_kernel_creators_.size() >= kMaxProviderNum) { | |||
| MS_LOG(ERROR) << "register too many provider!"; | |||
| return kLiteError; | |||
| } | |||
| if (provider_iter != custom_kernel_creators_.end()) { | |||
| auto arch_iter = provider_iter->second.find(arch); | |||
| if (arch_iter == provider_iter->second.end()) { | |||
| if (provider_iter->second.size() >= kMaxArchPerProviderNum) { | |||
| MS_LOG(ERROR) << "register too many arch!"; | |||
| return kLiteError; | |||
| } | |||
| } else { | |||
| auto type_iter = arch_iter->second.find(type); | |||
| if (type_iter == arch_iter->second.end() && arch_iter->second.size() >= kMaxCustomTypeNum) { | |||
| MS_LOG(ERROR) << "register too many type!"; | |||
| return kLiteError; | |||
| } | |||
| } | |||
| } | |||
| if (custom_kernel_creators_[provider][arch][type] == nullptr) { | |||
| custom_kernel_creators_[provider][arch][type] = | |||
| reinterpret_cast<CreateKernel *>(calloc(kDataTypeLen, sizeof(CreateKernel))); | |||
| @@ -64,20 +85,30 @@ Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::s | |||
| } | |||
| } | |||
| int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1; | |||
| if (data_type_index < 0 || data_type_index >= kDataTypeLen) { | |||
| MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider; | |||
| return kLiteError; | |||
| } | |||
| custom_kernel_creators_[provider][arch][type][data_type_index] = creator; | |||
| return kSuccess; | |||
| } | |||
| Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, | |||
| registry::CreateKernel creator) { | |||
| const registry::CreateKernel creator) { | |||
| if (type <= static_cast<int>(PrimitiveType_MIN) || type > static_cast<int>(PrimitiveType_MAX)) { | |||
| MS_LOG(ERROR) << "Invalid op type : " << type; | |||
| return kLiteParamInvalid; | |||
| } | |||
| KernelDesc desc = {data_type, type, arch, provider}; | |||
| int index = GetFuncIndex(desc); | |||
| if (index < 0) { | |||
| MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type " | |||
| << type; | |||
| return kLiteError; | |||
| } | |||
| std::unique_lock<std::mutex> lock(lock_); | |||
| auto iter = kernel_creators_.find(provider); | |||
| if (iter == kernel_creators_.end()) { | |||
| if (kernel_creators_.size() >= kMaxProviderNum) { | |||
| MS_LOG(ERROR) << "register too many provider!"; | |||
| return kLiteError; | |||
| } | |||
| kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel))); | |||
| if (kernel_creators_[provider][arch] == nullptr) { | |||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch; | |||
| @@ -86,6 +117,10 @@ Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string | |||
| } else { | |||
| auto iter_arch = iter->second.find(arch); | |||
| if (iter_arch == iter->second.end()) { | |||
| if (iter->second.size() >= kMaxArchPerProviderNum) { | |||
| MS_LOG(ERROR) << "register too many arch!"; | |||
| return kLiteError; | |||
| } | |||
| iter->second[arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel))); | |||
| if (iter->second[arch] == nullptr) { | |||
| MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch; | |||
| @@ -94,14 +129,6 @@ Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string | |||
| } | |||
| } | |||
| KernelDesc desc = {data_type, type, arch, provider}; | |||
| int index = GetFuncIndex(desc); | |||
| if (index < 0) { | |||
| MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type " | |||
| << type; | |||
| return kLiteError; | |||
| } | |||
| kernel_creators_[provider][arch][index] = creator; | |||
| return kSuccess; | |||
| } | |||
| @@ -109,11 +136,11 @@ Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string | |||
| registry::CreateKernel RegistryKernelImpl::GetCustomKernelCreator(const schema::Primitive *primitive, | |||
| KernelDesc *desc) { | |||
| int data_type_index = static_cast<int>(desc->data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1; | |||
| if (data_type_index < 0 || data_type_index >= kDataTypeLen) { | |||
| if (data_type_index < 0 || desc->data_type >= DataType::kNumberTypeEnd) { | |||
| return nullptr; | |||
| } | |||
| auto param = primitive->value_as_Custom(); | |||
| if (param == nullptr) { | |||
| if (param == nullptr || param->type() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto custom_type = param->type()->str(); | |||
| @@ -37,10 +37,10 @@ class RegistryKernelImpl { | |||
| } | |||
| Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, | |||
| const std::string &type, registry::CreateKernel creator); | |||
| const std::string &type, const registry::CreateKernel creator); | |||
| Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, | |||
| registry::CreateKernel creator); | |||
| const registry::CreateKernel creator); | |||
| virtual registry::CreateKernel GetProviderCreator(const schema::Primitive *primitive, registry::KernelDesc *desc); | |||
| @@ -60,7 +60,6 @@ class RegistryKernelImpl { | |||
| std::mutex lock_; | |||
| registry::CreateKernel GetCustomKernelCreator(const schema::Primitive *primitive, registry::KernelDesc *desc); | |||
| int GetFuncIndex(const registry::KernelDesc &desc); | |||
| }; | |||
| } // namespace mindspore::registry | |||
| @@ -22,7 +22,8 @@ | |||
| namespace mindspore { | |||
| namespace registry { | |||
| Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator) { | |||
| Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_type, | |||
| const KernelInterfaceCreator creator) { | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| return KernelInterfaceRegistry::Instance()->Reg(CharToString(provider), op_type, creator); | |||
| #else | |||
| @@ -32,7 +33,7 @@ Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_ty | |||
| } | |||
| Status RegisterKernelInterface::CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type, | |||
| KernelInterfaceCreator creator) { | |||
| const KernelInterfaceCreator creator) { | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| return KernelInterfaceRegistry::Instance()->CustomReg(CharToString(provider), CharToString(op_type), creator); | |||
| #else | |||
| @@ -41,10 +42,11 @@ Status RegisterKernelInterface::CustomReg(const std::vector<char> &provider, con | |||
| #endif | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface( | |||
| const std::vector<char> &provider, const schema::Primitive *primitive) { | |||
| std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::vector<char> &provider, | |||
| const schema::Primitive *primitive, | |||
| const kernel::Kernel *kernel) { | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive); | |||
| return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive, kernel); | |||
| #else | |||
| MS_LOG(ERROR) << unsupport_custom_kernel_register_log; | |||
| return nullptr; | |||
| @@ -34,23 +34,33 @@ namespace mindspore { | |||
| namespace lite { | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const void *primitive, std::set<std::string> &&providers, int schema_version) { | |||
| if (primitive == nullptr) { | |||
| const void *primitive, std::set<std::string> &&providers, int schema_version, | |||
| const kernel::Kernel *kernel) { | |||
| if (primitive == nullptr && kernel == nullptr) { | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| std::shared_ptr<kernel::KernelInterface> kernel_interface = nullptr; | |||
| if (IsCustomNode(primitive, schema_version)) { | |||
| kernel_interface = | |||
| registry::RegisterKernelInterface::GetKernelInterface("", static_cast<const schema::Primitive *>(primitive)); | |||
| bool is_custom_node = false; | |||
| if (kernel == nullptr) { | |||
| if (IsCustomNode(primitive, schema_version)) { | |||
| is_custom_node = true; | |||
| } | |||
| } else if (kernel->type() == schema::PrimitiveType_Custom) { | |||
| is_custom_node = true; | |||
| } | |||
| if (is_custom_node) { | |||
| kernel_interface = registry::RegisterKernelInterface::GetKernelInterface( | |||
| "", static_cast<const schema::Primitive *>(primitive), kernel); | |||
| } else { | |||
| for (auto &&provider : providers) { | |||
| kernel_interface = registry::RegisterKernelInterface::GetKernelInterface( | |||
| provider, static_cast<const schema::Primitive *>(primitive)); | |||
| provider, static_cast<const schema::Primitive *>(primitive), kernel); | |||
| if (kernel_interface != nullptr) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (kernel_interface == nullptr) { | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| @@ -26,13 +26,15 @@ | |||
| #include "src/tensor.h" | |||
| #include "nnacl/tensor_c.h" | |||
| #include "nnacl/infer/infer.h" | |||
| #include "include/api/kernel.h" | |||
| namespace mindspore::lite { | |||
| int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *parameter); | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const void *primitive, std::set<std::string> &&providers, int schema_version); | |||
| const void *primitive, std::set<std::string> &&providers, int schema_version, | |||
| const kernel::Kernel *kernel = nullptr); | |||
| #endif | |||
| class InferManager { | |||
| public: | |||
| @@ -428,7 +428,7 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel) | |||
| MS_ASSERT(conv_kernel); | |||
| MS_ASSERT(scale_kernel); | |||
| auto *scale_param = | |||
| reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel)->GetParameter()); | |||
| reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel->kernel())->GetParameter()); | |||
| MS_ASSERT(scale_param); | |||
| MS_ASSERT(conv_kernel->in_tensors().size() >= INPUT_TENSOR_SIZE_2); | |||
| auto *filter = conv_kernel->in_tensors().at(1); | |||
| @@ -1373,6 +1373,9 @@ kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src | |||
| SetKernelTensorDataType(kernel); | |||
| kernel->set_name(src_node->name_); | |||
| if (kernel->kernel() != nullptr) { | |||
| kernel->kernel()->SetConfig(config_info_); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -59,6 +59,9 @@ class Scheduler { | |||
| ~Scheduler() = default; | |||
| int Schedule(std::vector<kernel::LiteKernel *> *dst_kernels); | |||
| void SetupSchedulerCb(std::unique_ptr<SchedulerCb> cb) { sched_cb_ = std::move(cb); } | |||
| void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config_info) { | |||
| config_info_ = config_info; | |||
| } | |||
| private: | |||
| int SchedulePreProcess(); | |||
| @@ -165,6 +168,7 @@ class Scheduler { | |||
| #endif | |||
| int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR; | |||
| std::map<std::string, TypeId> *execution_plan_ = nullptr; | |||
| const std::map<std::string, std::map<std::string, std::string>> *config_info_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -92,7 +92,7 @@ int SubGraphKernel::ReSize() { | |||
| int ret; | |||
| #ifndef CUSTOM_KERNEL_REGISTRY_CLIP | |||
| ret = lite::KernelInferShape(inputs, outputs, kernel->kernel()->primitive(), kernel->Context()->GetProviders(), | |||
| schema_version_); | |||
| schema_version_, kernel->kernel()); | |||
| if (ret == lite::RET_NOT_SUPPORT) { | |||
| #endif | |||
| auto parameter = kernel->op_parameter(); | |||
| @@ -51,10 +51,10 @@ TEST_F(MixDataTypeTest, Config1) { | |||
| std::string filename = "MixDataTypeTestConfig"; | |||
| std::string sectionname = "execution_plan"; | |||
| std::map<std::string, std::string> config_info; | |||
| ret = lite::GetSectionInfoFromConfigFile(filename, sectionname, &config_info); | |||
| std::map<std::string, std::map<std::string, std::string>> configs; | |||
| ret = lite::GetAllSectionInfoFromConfigFile(filename, &configs); | |||
| ASSERT_EQ(ret, 0); | |||
| std::map<std::string, std::string> config_info = configs[sectionname]; | |||
| ASSERT_EQ(config_info.size(), 2); | |||
| auto info0 = config_info.at("op1"); | |||