/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_CXX_API_FACTORY_H #define MINDSPORE_CCSRC_CXX_API_FACTORY_H #include #include #include #include #include #include #include "utils/utils.h" namespace mindspore { inline std::string g_device_target = "Default"; template class Factory { using U = std::function()>; public: Factory(const Factory &) = delete; void operator=(const Factory &) = delete; static Factory &Instance() { static Factory instance; return instance; } void Register(const std::string &device_name, U &&creator) { if (creators_.find(device_name) == creators_.end()) { (void)creators_.emplace(device_name, creator); } } bool CheckModelSupport(const std::string &device_name) { return std::any_of(creators_.begin(), creators_.end(), [&device_name](const std::pair &item) { return item.first == device_name; }); } std::shared_ptr Create(const std::string &device_name) { auto iter = creators_.find(device_name); if (creators_.end() != iter) { MS_EXCEPTION_IF_NULL(iter->second); return (iter->second)(); } MS_LOG(ERROR) << "Unsupported device target " << device_name; return nullptr; } private: Factory() = default; ~Factory() = default; std::map creators_; }; template class Registrar { using U = std::function()>; public: Registrar(const std::string &device_name, U creator) { Factory::Instance().Register(device_name, std::move(creator)); } ~Registrar() = default; }; #define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \ static const Registrar g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \ #DEVICE_NAME, []() { return std::make_shared(); }); } // namespace mindspore #endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H