Merge pull request !27495 from 徐永飞/mastertags/v1.6.0
| @@ -28,6 +28,7 @@ enum DeviceType { | |||
| kCPU = 0, | |||
| kGPU, | |||
| kKirinNPU, | |||
| kAscend, | |||
| kAscend910, | |||
| kAscend310, | |||
| // add new type here | |||
| @@ -287,34 +288,14 @@ void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) { | |||
| } | |||
| std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } | |||
| /// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend910. This option is | |||
| /// invalid for MindSpore Lite. | |||
| class MS_API Ascend910DeviceInfo : public DeviceInfoContext { | |||
| public: | |||
| /// \brief Get the type of this DeviceInfoContext. | |||
| /// | |||
| /// \return Type of this DeviceInfoContext. | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; }; | |||
| /// \brief Set device id. | |||
| /// | |||
| /// \param[in] device_id The device id. | |||
| void SetDeviceID(uint32_t device_id); | |||
| /// \brief Get the device id. | |||
| /// | |||
| /// \return The device id. | |||
| uint32_t GetDeviceID() const; | |||
| }; | |||
| /// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is | |||
| /// invalid for MindSpore Lite. | |||
| class MS_API Ascend310DeviceInfo : public DeviceInfoContext { | |||
| class MS_API AscendDeviceInfo : public DeviceInfoContext { | |||
| public: | |||
| /// \brief Get the type of this DeviceInfoContext. | |||
| /// | |||
| /// \return Type of this DeviceInfoContext. | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; }; | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; }; | |||
| /// \brief Set device id. | |||
| /// | |||
| @@ -447,45 +428,48 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext { | |||
| std::vector<char> GetBufferOptimizeModeChar() const; | |||
| }; | |||
| void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { | |||
| using Ascend310DeviceInfo = AscendDeviceInfo; | |||
| using Ascend910DeviceInfo = AscendDeviceInfo; | |||
| void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { | |||
| SetInsertOpConfigPath(StringToChar(cfg_path)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } | |||
| std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } | |||
| void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } | |||
| std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } | |||
| void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } | |||
| std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } | |||
| void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } | |||
| std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } | |||
| void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } | |||
| std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } | |||
| std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } | |||
| std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } | |||
| void Ascend310DeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) { | |||
| void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) { | |||
| SetDynamicImageSize(StringToChar(dynamic_image_size)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); } | |||
| std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); } | |||
| void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) { | |||
| void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) { | |||
| SetPrecisionMode(StringToChar(precision_mode)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } | |||
| std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } | |||
| void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) { | |||
| void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) { | |||
| SetOpSelectImplMode(StringToChar(op_select_impl_mode)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); } | |||
| std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); } | |||
| void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) { | |||
| void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) { | |||
| SetFusionSwitchConfigPath(StringToChar(cfg_path)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const { | |||
| std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const { | |||
| return CharToString(GetFusionSwitchConfigPathChar()); | |||
| } | |||
| void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { | |||
| void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { | |||
| SetBufferOptimizeMode(StringToChar(buffer_optimize_mode)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } | |||
| std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| @@ -193,7 +193,7 @@ class MS_API Model { | |||
| /// \brief Inference model. | |||
| /// | |||
| /// \param[in] device_type Device type,options are kGPU, kAscend910, etc. | |||
| /// \param[in] device_type Device type,options are kGPU, kAscend, kAscend910, etc. | |||
| /// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM. | |||
| /// | |||
| /// \return Is supported or not. | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H | |||
| #define MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H | |||
| #include <string> | |||
| #include "acl/acl_base.h" | |||
| namespace mindspore { | |||
| static inline bool IsAscend910Soc() { | |||
| const char *soc_name_c = aclrtGetSocName(); | |||
| if (soc_name_c == nullptr) { | |||
| return false; | |||
| } | |||
| std::string soc_name(soc_name_c); | |||
| if (soc_name.find("910") == std::string::npos) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| static inline bool IsAscendNo910Soc() { | |||
| const char *soc_name_c = aclrtGetSocName(); | |||
| if (soc_name_c == nullptr) { | |||
| return false; | |||
| } | |||
| std::string soc_name(soc_name_c); | |||
| if (soc_name.find("910") != std::string::npos) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H | |||
| @@ -175,55 +175,46 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const { | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend910DeviceID] = device_id; | |||
| } | |||
| uint32_t Ascend910DeviceInfo::GetDeviceID() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| return GetValue<uint32_t>(data_, kModelOptionAscend910DeviceID); | |||
| } | |||
| void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) { | |||
| void AscendDeviceInfo::SetDeviceID(uint32_t device_id) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend310DeviceID] = device_id; | |||
| } | |||
| uint32_t Ascend310DeviceInfo::GetDeviceID() const { | |||
| uint32_t AscendDeviceInfo::GetDeviceID() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID); | |||
| } | |||
| void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) { | |||
| void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath); | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) { | |||
| void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend310InputFormat] = CharToString(format); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetInputFormatChar() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat); | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) { | |||
| void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend310InputShape] = CharToString(shape); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetInputShapeChar() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape); | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) { | |||
| void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| std::string batchs = ""; | |||
| for (size_t i = 0; i < dynamic_batch_size.size(); ++i) { | |||
| @@ -234,69 +225,69 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic | |||
| } | |||
| data_->params[kModelOptionAscend310DynamicBatchSize] = batchs; | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize); | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; } | |||
| void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; } | |||
| std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); } | |||
| std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); } | |||
| void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) { | |||
| void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode); | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) { | |||
| void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode); | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) { | |||
| void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath); | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) { | |||
| void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend310InputShapeMap] = shape; | |||
| } | |||
| std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const { | |||
| std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap); | |||
| } | |||
| void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { | |||
| void AscendDeviceInfo::SetOutputType(enum DataType output_type) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend310OutputType] = output_type; | |||
| } | |||
| enum DataType Ascend310DeviceInfo::GetOutputType() const { | |||
| enum DataType AscendDeviceInfo::GetOutputType() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType); | |||
| } | |||
| void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) { | |||
| void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const { | |||
| MS_EXCEPTION_IF_NULL(data_); | |||
| const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize); | |||
| return StringToChar(ref); | |||
| @@ -24,7 +24,31 @@ | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| inline std::string g_device_target = "Default"; | |||
| inline enum DeviceType g_device_target = kInvalidDeviceType; | |||
| static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) { | |||
| switch (device_type) { | |||
| case kAscend: | |||
| stream << "Ascend"; | |||
| break; | |||
| case kAscend910: | |||
| stream << "Ascend910"; | |||
| break; | |||
| case kAscend310: | |||
| stream << "Ascend310"; | |||
| break; | |||
| case kGPU: | |||
| stream << "GPU"; | |||
| break; | |||
| case kCPU: | |||
| stream << "CPU"; | |||
| break; | |||
| default: | |||
| stream << "[InvalidDeviceType: " << static_cast<int>(device_type) << "]"; | |||
| break; | |||
| } | |||
| return stream; | |||
| } | |||
| template <class T> | |||
| class Factory { | |||
| @@ -39,32 +63,24 @@ class Factory { | |||
| return instance; | |||
| } | |||
| void Register(const std::string &device_name, U &&creator) { | |||
| if (creators_.find(device_name) == creators_.end()) { | |||
| (void)creators_.emplace(device_name, creator); | |||
| } | |||
| } | |||
| bool CheckModelSupport(const std::string &device_name) { | |||
| return std::any_of(creators_.begin(), creators_.end(), | |||
| [&device_name](const std::pair<std::string, U> &item) { return item.first == device_name; }); | |||
| } | |||
| void Register(U &&creator) { creators_.push_back(creator); } | |||
| std::shared_ptr<T> Create(const std::string &device_name) { | |||
| auto iter = creators_.find(device_name); | |||
| if (creators_.end() != iter) { | |||
| MS_EXCEPTION_IF_NULL(iter->second); | |||
| return (iter->second)(); | |||
| std::shared_ptr<T> Create(enum DeviceType device_type) { | |||
| for (auto &item : creators_) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| auto val = item(); | |||
| if (val->CheckDeviceSupport(device_type)) { | |||
| return val; | |||
| } | |||
| } | |||
| MS_LOG(ERROR) << "Unsupported device target " << device_name; | |||
| MS_LOG(WARNING) << "Unsupported device target " << device_type; | |||
| return nullptr; | |||
| } | |||
| private: | |||
| Factory() = default; | |||
| ~Factory() = default; | |||
| std::map<std::string, U> creators_; | |||
| std::vector<U> creators_; | |||
| }; | |||
| template <class T> | |||
| @@ -72,14 +88,12 @@ class Registrar { | |||
| using U = std::function<std::shared_ptr<T>()>; | |||
| public: | |||
| Registrar(const std::string &device_name, U creator) { | |||
| Factory<T>::Instance().Register(device_name, std::move(creator)); | |||
| } | |||
| explicit Registrar(U creator) { Factory<T>::Instance().Register(std::move(creator)); } | |||
| ~Registrar() = default; | |||
| }; | |||
| #define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \ | |||
| static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \ | |||
| #DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); }); | |||
| #define API_FACTORY_REG(BASE_CLASS, DERIVE_CLASS) \ | |||
| static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_reg( \ | |||
| []() { return std::make_shared<DERIVE_CLASS>(); }); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H | |||
| @@ -18,9 +18,10 @@ | |||
| #include "cxx_api/model/acl/model_converter.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "mindspore/core/utils/convert_utils_base.h" | |||
| #include "cxx_api/acl_utils.h" | |||
| namespace mindspore { | |||
| API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl); | |||
| API_FACTORY_REG(GraphCell::GraphImpl, AclGraphImpl); | |||
| AclGraphImpl::AclGraphImpl() | |||
| : init_flag_(false), | |||
| @@ -231,4 +232,12 @@ Status AclGraphImpl::ConvertToOM() { | |||
| MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType(); | |||
| return kMCFailed; | |||
| } | |||
| bool AclGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { | |||
| // for Ascend, only support kAscend and kAscend310 | |||
| if (device_type != kAscend && device_type != kAscend310) { | |||
| return false; | |||
| } | |||
| return IsAscendNo910Soc(); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -37,6 +37,7 @@ class AclGraphImpl : public GraphCell::GraphImpl { | |||
| Status Load(uint32_t device_id) override; | |||
| std::vector<MSTensor> GetInputs() override; | |||
| std::vector<MSTensor> GetOutputs() override; | |||
| bool CheckDeviceSupport(mindspore::DeviceType device_type) override; | |||
| private: | |||
| Status ConvertToOM(); | |||
| @@ -18,6 +18,7 @@ | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/factory.h" | |||
| #include "cxx_api/akg_kernel_register.h" | |||
| #include "cxx_api/acl_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/context/context_extends.h" | |||
| #include "mindspore/core/base/base_ref_utils.h" | |||
| @@ -30,7 +31,7 @@ | |||
| #include "pybind11/pybind11.h" | |||
| namespace mindspore { | |||
| API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); | |||
| API_FACTORY_REG(GraphCell::GraphImpl, AscendGraphImpl); | |||
| static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL"; | |||
| static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE"; | |||
| @@ -382,6 +383,14 @@ std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv | |||
| return acl_env; | |||
| } | |||
| bool AscendGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { | |||
| // for Ascend, only support kAscend and kAscend910 | |||
| if (device_type != kAscend && device_type != kAscend910) { | |||
| return false; | |||
| } | |||
| return IsAscend910Soc(); | |||
| } | |||
| std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_; | |||
| std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_; | |||
| @@ -39,6 +39,7 @@ class AscendGraphImpl : public GraphCell::GraphImpl { | |||
| Status Load(uint32_t device_id) override; | |||
| std::vector<MSTensor> GetInputs() override; | |||
| std::vector<MSTensor> GetOutputs() override; | |||
| bool CheckDeviceSupport(mindspore::DeviceType device_type) override; | |||
| private: | |||
| class MsEnvGuard; | |||
| @@ -26,7 +26,7 @@ | |||
| #include "runtime/device/gpu/cuda_driver.h" | |||
| namespace mindspore { | |||
| API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl); | |||
| API_FACTORY_REG(GraphCell::GraphImpl, GPUGraphImpl); | |||
| GPUGraphImpl::GPUGraphImpl() | |||
| : session_impl_(nullptr), | |||
| @@ -291,4 +291,6 @@ std::vector<MSTensor> GPUGraphImpl::GetOutputs() { | |||
| } | |||
| return result; | |||
| } | |||
| bool GPUGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { return device_type == kGPU; } | |||
| } // namespace mindspore | |||
| @@ -37,6 +37,8 @@ class GPUGraphImpl : public GraphCell::GraphImpl { | |||
| std::vector<MSTensor> GetInputs() override; | |||
| std::vector<MSTensor> GetOutputs() override; | |||
| bool CheckDeviceSupport(mindspore::DeviceType device_type) override; | |||
| private: | |||
| Status InitEnv(); | |||
| Status FinalizeEnv(); | |||
| @@ -42,6 +42,8 @@ class GraphCell::GraphImpl { | |||
| virtual std::vector<MSTensor> GetInputs() = 0; | |||
| virtual std::vector<MSTensor> GetOutputs() = 0; | |||
| virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0; | |||
| protected: | |||
| std::shared_ptr<Graph> graph_; | |||
| std::shared_ptr<Context> graph_context_; | |||
| @@ -21,7 +21,7 @@ | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/factory.h" | |||
| #include "cxx_api/graph/acl/acl_env_guard.h" | |||
| #include "acl/acl_base.h" | |||
| #include "cxx_api/acl_utils.h" | |||
| namespace mindspore { | |||
| Status AclModel::Build() { | |||
| @@ -112,7 +112,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s | |||
| if (model_context_ == nullptr) { | |||
| model_context_ = std::make_shared<Context>(); | |||
| model_context_->MutableDeviceInfo().emplace_back(std::make_shared<Ascend310DeviceInfo>()); | |||
| model_context_->MutableDeviceInfo().emplace_back(std::make_shared<AscendDeviceInfo>()); | |||
| } | |||
| std::string input_shape_option; | |||
| @@ -139,7 +139,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s | |||
| MS_LOG(ERROR) << "Invalid model context, only single device info is supported."; | |||
| return kMCInvalidArgs; | |||
| } | |||
| auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>(); | |||
| auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>(); | |||
| MS_EXCEPTION_IF_NULL(ascend310_info); | |||
| ascend310_info->SetInputShape(input_shape_option); | |||
| auto graph_cell_bak = std::move(graph_cell_); | |||
| @@ -163,16 +163,15 @@ std::vector<MSTensor> AclModel::GetOutputs() { | |||
| return graph_cell_->GetOutputs(); | |||
| } | |||
| bool AclModel::CheckModelSupport(enum ModelType model_type) { | |||
| const char *soc_name_c = aclrtGetSocName(); | |||
| if (soc_name_c == nullptr) { | |||
| return false; | |||
| } | |||
| std::string soc_name(soc_name_c); | |||
| if (soc_name.find("910") != std::string::npos) { | |||
| bool AclModel::CheckDeviceSupport(mindspore::DeviceType device_type) { | |||
| // for Ascend, only support kAscend and kAscend310 | |||
| if (device_type != kAscend && device_type != kAscend310) { | |||
| return false; | |||
| } | |||
| return IsAscendNo910Soc(); | |||
| } | |||
| bool AclModel::CheckModelSupport(enum ModelType model_type) { | |||
| static const std::set<ModelType> kSupportedModelMap = {kMindIR, kOM}; | |||
| auto iter = kSupportedModelMap.find(model_type); | |||
| if (iter == kSupportedModelMap.end()) { | |||
| @@ -43,6 +43,7 @@ class AclModel : public ModelImpl { | |||
| std::vector<MSTensor> GetInputs() override; | |||
| std::vector<MSTensor> GetOutputs() override; | |||
| bool CheckDeviceSupport(mindspore::DeviceType device_type) override; | |||
| bool CheckModelSupport(enum ModelType model_type) override; | |||
| private: | |||
| @@ -29,7 +29,7 @@ | |||
| #include "cxx_api/model/acl/acl_vm/acl_vm.h" | |||
| namespace mindspore { | |||
| API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti); | |||
| API_FACTORY_REG(ModelImpl, AclModelMulti); | |||
| namespace { | |||
| std::map<DataType, size_t> kDtypeMap = { | |||
| @@ -33,7 +33,7 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) { | |||
| if (device_infos.size() != 1) { | |||
| return; | |||
| } | |||
| auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>(); | |||
| auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>(); | |||
| if (ascend310_info == nullptr) { | |||
| return; | |||
| } | |||
| @@ -20,19 +20,6 @@ | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace { | |||
| std::string GetDeviceTypeString(enum DeviceType type) { | |||
| static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = { | |||
| {kCPU, "CPU"}, {kGPU, "GPU"}, {kKirinNPU, "KirinGPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"}, | |||
| }; | |||
| auto iter = kDeviceTypeStrs.find(type); | |||
| if (iter != kDeviceTypeStrs.end()) { | |||
| return iter->second; | |||
| } | |||
| return "InvalidDeviceType" + std::to_string(static_cast<int>(type)); | |||
| } | |||
| } // namespace | |||
| Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context, | |||
| const std::shared_ptr<TrainCfg> &) { | |||
| if (graph_cell.GetGraph() == nullptr) { | |||
| @@ -50,7 +37,7 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_ | |||
| return kMCInvalidInput; | |||
| } | |||
| std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType()); | |||
| auto device_target = device_info[0]->GetDeviceType(); | |||
| impl_ = Factory<ModelImpl>::Instance().Create(device_target); | |||
| if (impl_ == nullptr) { | |||
| MS_LOG(ERROR) << "Create session type " << device_target << " failed"; | |||
| @@ -175,16 +162,10 @@ Model::Model() : impl_(nullptr) {} | |||
| Model::~Model() {} | |||
| bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) { | |||
| std::string device_type_str = GetDeviceTypeString(device_type); | |||
| if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) { | |||
| return false; | |||
| } | |||
| auto check_model = Factory<ModelImpl>::Instance().Create(device_type_str); | |||
| auto check_model = Factory<ModelImpl>::Instance().Create(device_type); | |||
| if (check_model == nullptr) { | |||
| return false; | |||
| } | |||
| return check_model->CheckModelSupport(model_type); | |||
| } | |||
| @@ -44,7 +44,8 @@ class ModelImpl { | |||
| virtual std::vector<MSTensor> GetInputs() = 0; | |||
| virtual std::vector<MSTensor> GetOutputs() = 0; | |||
| virtual bool CheckModelSupport(enum ModelType model_type) { return false; } | |||
| virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0; | |||
| virtual bool CheckModelSupport(enum ModelType model_type) = 0; | |||
| virtual Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs); | |||
| @@ -20,14 +20,13 @@ | |||
| #include "include/api/context.h" | |||
| #include "utils/ms_context.h" | |||
| #include "cxx_api/factory.h" | |||
| #if ENABLE_D | |||
| #include "cxx_api/acl_utils.h" | |||
| #endif | |||
| namespace mindspore { | |||
| // mindspore-serving check current package for version check with ModelImpl factory. | |||
| #if ENABLE_D | |||
| API_FACTORY_REG(ModelImpl, Ascend910, MsModel); | |||
| #elif ENABLE_GPU | |||
| API_FACTORY_REG(ModelImpl, GPU, MsModel); | |||
| #endif | |||
| API_FACTORY_REG(ModelImpl, MsModel); | |||
| static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) { | |||
| std::string shape_key; | |||
| @@ -171,18 +170,23 @@ uint32_t MsModel::GetDeviceID() const { | |||
| return 0; | |||
| } | |||
| bool MsModel::CheckModelSupport(enum ModelType model_type) { | |||
| bool MsModel::CheckDeviceSupport(enum DeviceType device_type) { | |||
| #if ENABLE_D | |||
| const char *soc_name_c = aclrtGetSocName(); | |||
| if (soc_name_c == nullptr) { | |||
| // for Ascend, only support kAscend or kAscend910 | |||
| if (device_type != kAscend && device_type != kAscend910) { | |||
| return false; | |||
| } | |||
| std::string soc_name(soc_name_c); | |||
| if (soc_name.find("910") == std::string::npos) { | |||
| return IsAscend910Soc(); | |||
| #else | |||
| // otherwise, only support GPU | |||
| if (device_type != kGPU) { | |||
| return false; | |||
| } | |||
| return true; | |||
| #endif | |||
| } | |||
| bool MsModel::CheckModelSupport(mindspore::ModelType model_type) { | |||
| static const std::set<ModelType> kSupportedModelMap = {kMindIR}; | |||
| auto iter = kSupportedModelMap.find(model_type); | |||
| if (iter == kSupportedModelMap.end()) { | |||
| @@ -44,6 +44,7 @@ class MsModel : public ModelImpl { | |||
| std::vector<MSTensor> GetInputs() override; | |||
| std::vector<MSTensor> GetOutputs() override; | |||
| bool CheckDeviceSupport(mindspore::DeviceType device_type) override; | |||
| bool CheckModelSupport(enum ModelType model_type) override; | |||
| private: | |||
| @@ -1,18 +1,18 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| #define MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| @@ -25,212 +25,431 @@ | |||
| namespace mindspore { | |||
| enum DeviceType { | |||
| kCPU = 0, | |||
| kGPU, | |||
| kKirinNPU, | |||
| kAscend910, | |||
| kAscend310, | |||
| // add new type here | |||
| kInvalidDeviceType = 100, | |||
| kCPU = 0, | |||
| kGPU, | |||
| kKirinNPU, | |||
| kAscend, | |||
| kAscend910, | |||
| kAscend310, | |||
| // add new type here | |||
| kInvalidDeviceType = 100, | |||
| }; | |||
| class Allocator; | |||
| class Delegate; | |||
| class DeviceInfoContext; | |||
| /// \brief Context is used to store environment variables during execution. | |||
| class MS_API Context { | |||
| public: | |||
| Context(); | |||
| ~Context() = default; | |||
| void SetThreadNum(int32_t thread_num); | |||
| int32_t GetThreadNum() const; | |||
| void SetAllocator(const std::shared_ptr<Allocator> &allocator); | |||
| std::shared_ptr<Allocator> GetAllocator() const; | |||
| std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo(); | |||
| private: | |||
| struct Data; | |||
| std::shared_ptr<Data> data_; | |||
| public: | |||
| struct Data; | |||
| Context(); | |||
| ~Context() = default; | |||
| /// \brief Set the number of threads at runtime. Only valid for Lite. | |||
| /// | |||
| /// \param[in] thread_num the number of threads at runtime. | |||
| void SetThreadNum(int32_t thread_num); | |||
| /// \brief Get the current thread number setting. Only valid for Lite. | |||
| /// | |||
| /// \return The current thread number setting. | |||
| int32_t GetThreadNum() const; | |||
| /// \brief Set the thread affinity to CPU cores. Only valid for Lite. | |||
| /// | |||
| /// \param[in] mode: 0: no affinities, 1: big cores first, 2: little cores first | |||
| void SetThreadAffinity(int mode); | |||
| /// \brief Get the thread affinity of CPU cores. Only valid for Lite. | |||
| /// | |||
| /// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first | |||
| int GetThreadAffinityMode() const; | |||
| /// \brief Set the thread lists to CPU cores. Only valid for Lite. | |||
| /// | |||
| /// \note If core_list and mode are set by SetThreadAffinity at the same time, the core_list is effective, but the | |||
| /// mode is not effective. | |||
| /// | |||
| /// \param[in] core_list: a vector of thread core lists. | |||
| void SetThreadAffinity(const std::vector<int> &core_list); | |||
| /// \brief Get the thread lists of CPU cores. Only valid for Lite. | |||
| /// | |||
| /// \return core_list: a vector of thread core lists. | |||
| std::vector<int32_t> GetThreadAffinityCoreList() const; | |||
| /// \brief Set the status whether to perform model inference or training in parallel. Only valid for Lite. | |||
| /// | |||
| /// \param[in] is_parallel: true, parallel; false, not in parallel. | |||
| void SetEnableParallel(bool is_parallel); | |||
| /// \brief Get the status whether to perform model inference or training in parallel. Only valid for Lite. | |||
| /// | |||
| /// \return Bool value that indicates whether in parallel. | |||
| bool GetEnableParallel() const; | |||
| /// \brief Set Delegate to access third-party AI framework. Only valid for Lite. | |||
| /// | |||
| /// \param[in] Pointer to the custom delegate. | |||
| void SetDelegate(const std::shared_ptr<Delegate> &delegate); | |||
| /// \brief Get the delegate of the third-party AI framework. Only valid for Lite. | |||
| /// | |||
| /// \return Pointer to the custom delegate. | |||
| std::shared_ptr<Delegate> GetDelegate() const; | |||
| /// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports | |||
| /// heterogeneous scenarios with multiple members in the vector. | |||
| /// | |||
| /// \return Mutable reference of DeviceInfoContext vector in this context. | |||
| std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo(); | |||
| private: | |||
| std::shared_ptr<Data> data_; | |||
| }; | |||
| /// \brief DeviceInfoContext defines different device contexts. | |||
| class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> { | |||
| public: | |||
| struct Data; | |||
| DeviceInfoContext(); | |||
| virtual ~DeviceInfoContext() = default; | |||
| virtual enum DeviceType GetDeviceType() const = 0; | |||
| template <class T> | |||
| std::shared_ptr<T> Cast() { | |||
| static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type."); | |||
| if (GetDeviceType() != T().GetDeviceType()) { | |||
| return nullptr; | |||
| } | |||
| return std::static_pointer_cast<T>(shared_from_this()); | |||
| } | |||
| protected: | |||
| std::shared_ptr<Data> data_; | |||
| public: | |||
| struct Data; | |||
| DeviceInfoContext(); | |||
| virtual ~DeviceInfoContext() = default; | |||
| /// \brief Get the type of this DeviceInfoContext. | |||
| /// | |||
| /// \return Type of this DeviceInfoContext. | |||
| virtual enum DeviceType GetDeviceType() const = 0; | |||
| /// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts | |||
| /// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails. | |||
| /// | |||
| /// \param T Type | |||
| /// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr. | |||
| template <class T> | |||
| std::shared_ptr<T> Cast() { | |||
| static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type."); | |||
| if (GetDeviceType() != T().GetDeviceType()) { | |||
| return nullptr; | |||
| } | |||
| return std::static_pointer_cast<T>(shared_from_this()); | |||
| } | |||
| /// \brief obtain provider's name | |||
| /// | |||
| /// \return provider's name. | |||
| std::string GetProvider() const; | |||
| /// \brief set provider's name. | |||
| /// | |||
| /// \param[in] provider define the provider's name. | |||
| void SetProvider(const std::string &provider); | |||
| /// \brief obtain provider's device type. | |||
| /// | |||
| /// \return provider's device type. | |||
| std::string GetProviderDevice() const; | |||
| /// \brief set provider's device type. | |||
| /// | |||
| /// \param[in] device define the provider's device type.EG: CPU. | |||
| void SetProviderDevice(const std::string &device); | |||
| /// \brief set memory allocator. | |||
| /// | |||
| /// \param[in] allocator define the memory allocator which can be defined by user. | |||
| void SetAllocator(const std::shared_ptr<Allocator> &allocator); | |||
| /// \brief obtain memory allocator. | |||
| /// | |||
| /// \return memory allocator. | |||
| std::shared_ptr<Allocator> GetAllocator() const; | |||
| protected: | |||
| std::shared_ptr<Data> data_; | |||
| }; | |||
| /// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid | |||
| /// for MindSpore Lite. | |||
| class MS_API CPUDeviceInfo : public DeviceInfoContext { | |||
| public: | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; }; | |||
| /// \brief Set the thread affinity to CPU cores. | |||
| /// | |||
| /// \param mode: 0: no affinities, 1: big cores first, 2: little cores first | |||
| void SetThreadAffinity(int mode); | |||
| int GetThreadAffinity() const; | |||
| void SetEnableFP16(bool is_fp16); | |||
| bool GetEnableFP16() const; | |||
| public: | |||
| /// \brief Get the type of this DeviceInfoContext. | |||
| /// | |||
| /// \return Type of this DeviceInfoContext. | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; }; | |||
| /// \brief Set enables to perform the float16 inference | |||
| /// | |||
| /// \param[in] is_fp16 Enable float16 inference or not. | |||
| void SetEnableFP16(bool is_fp16); | |||
| /// \brief Get enables to perform the float16 inference | |||
| /// | |||
| /// \return Whether enable float16 inference. | |||
| bool GetEnableFP16() const; | |||
| }; | |||
| /// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid | |||
| /// for MindSpore Lite. | |||
| class MS_API KirinNPUDeviceInfo : public DeviceInfoContext { | |||
| public: | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; }; | |||
| void SetFrequency(int frequency); | |||
| int GetFrequency() const; | |||
| public: | |||
| /// \brief Get the type of this DeviceInfoContext. | |||
| /// | |||
| /// \return Type of this DeviceInfoContext. | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; }; | |||
| /// \brief Set the NPU frequency. | |||
| /// | |||
| /// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme | |||
| /// performance), default as 3. | |||
| void SetFrequency(int frequency); | |||
| /// \brief Get the NPU frequency. | |||
| /// | |||
| /// \return NPU frequency | |||
| int GetFrequency() const; | |||
| }; | |||
| /// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU. | |||
| class MS_API GPUDeviceInfo : public DeviceInfoContext { | |||
| public: | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; }; | |||
| void SetDeviceID(uint32_t device_id); | |||
| uint32_t GetDeviceID() const; | |||
| void SetGpuTrtInferMode(bool gpu_trt_infer_mode); | |||
| bool GetGpuTrtInferMode() const; | |||
| void SetEnableFP16(bool is_fp16); | |||
| bool GetEnableFP16() const; | |||
| public: | |||
| /// \brief Get the type of this DeviceInfoContext. | |||
| /// | |||
| /// \return Type of this DeviceInfoContext. | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; }; | |||
| /// \brief Set device id. | |||
| /// | |||
| /// \param[in] device_id The device id. | |||
| void SetDeviceID(uint32_t device_id); | |||
| /// \brief Get the device id. | |||
| /// | |||
| /// \return The device id. | |||
| uint32_t GetDeviceID() const; | |||
| /// \brief Get the distribution rank id. | |||
| /// | |||
| /// \return The device id. | |||
| int GetRankID() const; | |||
| /// \brief Get the distribution group size. | |||
| /// | |||
| /// \return The device id. | |||
| int GetGroupSize() const; | |||
| /// \brief Set the precision mode. | |||
| /// | |||
| /// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default. | |||
| inline void SetPrecisionMode(const std::string &precision_mode); | |||
| /// \brief Get the precision mode. | |||
| /// | |||
| /// \return The precision mode. | |||
| inline std::string GetPrecisionMode() const; | |||
| /// \brief Set enables to perform the float16 inference | |||
| /// | |||
| /// \param[in] is_fp16 Enable float16 inference or not. | |||
| void SetEnableFP16(bool is_fp16); | |||
| /// \brief Get enables to perform the float16 inference | |||
| /// | |||
| /// \return Whether enable float16 inference. | |||
| bool GetEnableFP16() const; | |||
| private: | |||
| void SetPrecisionMode(const std::vector<char> &precision_mode); | |||
| std::vector<char> GetPrecisionModeChar() const; | |||
| }; | |||
| class MS_API Ascend910DeviceInfo : public DeviceInfoContext { | |||
| public: | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; }; | |||
| void SetDeviceID(uint32_t device_id); | |||
| uint32_t GetDeviceID() const; | |||
| void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) { | |||
| SetPrecisionMode(StringToChar(precision_mode)); | |||
| } | |||
| std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } | |||
| /// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is | |||
| /// invalid for MindSpore Lite. | |||
| class MS_API AscendDeviceInfo : public DeviceInfoContext { | |||
| public: | |||
| /// \brief Get the type of this DeviceInfoContext. | |||
| /// | |||
| /// \return Type of this DeviceInfoContext. | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; }; | |||
| /// \brief Set device id. | |||
| /// | |||
| /// \param[in] device_id The device id. | |||
| void SetDeviceID(uint32_t device_id); | |||
| /// \brief Get the device id. | |||
| /// | |||
| /// \return The device id. | |||
| uint32_t GetDeviceID() const; | |||
| /// \brief Set AIPP configuration file path. | |||
| /// | |||
| /// \param[in] cfg_path AIPP configuration file path. | |||
| inline void SetInsertOpConfigPath(const std::string &cfg_path); | |||
| /// \brief Get AIPP configuration file path. | |||
| /// | |||
| /// \return AIPP configuration file path. | |||
| inline std::string GetInsertOpConfigPath() const; | |||
| /// \brief Set format of model inputs. | |||
| /// | |||
| /// \param[in] format Optional "NCHW", "NHWC", etc. | |||
| inline void SetInputFormat(const std::string &format); | |||
| /// \brief Get format of model inputs. | |||
| /// | |||
| /// \return The format of model inputs. | |||
| inline std::string GetInputFormat() const; | |||
| /// \brief Set shape of model inputs. | |||
| /// | |||
| /// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1". | |||
| inline void SetInputShape(const std::string &shape); | |||
| /// \brief Get shape of model inputs. | |||
| /// | |||
| /// \return The shape of model inputs. | |||
| inline std::string GetInputShape() const; | |||
| /// \brief Set shape of model inputs. | |||
| /// | |||
| /// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input | |||
| /// shape 4,3,2,1. | |||
| void SetInputShapeMap(const std::map<int, std::vector<int>> &shape); | |||
| /// \brief Get shape of model inputs. | |||
| /// | |||
| /// \return The shape of model inputs. | |||
| std::map<int, std::vector<int>> GetInputShapeMap() const; | |||
| void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size); | |||
| inline std::string GetDynamicBatchSize() const; | |||
| /// \brief Set the dynamic image size of model inputs. | |||
| /// | |||
| /// \param[in] image size hw e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64. | |||
| inline void SetDynamicImageSize(const std::string &dynamic_image_size); | |||
| /// \brief Get dynamic image size of model inputs. | |||
| /// | |||
| /// \return The image size of model inputs. | |||
| inline std::string GetDynamicImageSize() const; | |||
| /// \brief Set type of model outputs. | |||
| /// | |||
| /// \param[in] output_type FP32, UINT8 or FP16, default as FP32. | |||
| void SetOutputType(enum DataType output_type); | |||
| /// \brief Get type of model outputs. | |||
| /// | |||
| /// \return The set type of model outputs. | |||
| enum DataType GetOutputType() const; | |||
| /// \brief Set precision mode of model. | |||
| /// | |||
| /// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and | |||
| /// "allow_mix_precision", "force_fp16" is set as default | |||
| inline void SetPrecisionMode(const std::string &precision_mode); | |||
| /// \brief Get precision mode of model. | |||
| /// | |||
| /// \return The set type of model outputs | |||
| inline std::string GetPrecisionMode() const; | |||
| /// \brief Set op select implementation mode. | |||
| /// | |||
| /// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as | |||
| /// default. | |||
| inline void SetOpSelectImplMode(const std::string &op_select_impl_mode); | |||
| /// \brief Get op select implementation mode. | |||
| /// | |||
| /// \return The set op select implementation mode. | |||
| inline std::string GetOpSelectImplMode() const; | |||
| inline void SetFusionSwitchConfigPath(const std::string &cfg_path); | |||
| inline std::string GetFusionSwitchConfigPath() const; | |||
| // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize" | |||
| inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode); | |||
| inline std::string GetBufferOptimizeMode() const; | |||
| private: | |||
| void SetInsertOpConfigPath(const std::vector<char> &cfg_path); | |||
| std::vector<char> GetInsertOpConfigPathChar() const; | |||
| void SetInputFormat(const std::vector<char> &format); | |||
| std::vector<char> GetInputFormatChar() const; | |||
| void SetInputShape(const std::vector<char> &shape); | |||
| std::vector<char> GetInputShapeChar() const; | |||
| std::vector<char> GetDynamicBatchSizeChar() const; | |||
| void SetDynamicImageSize(const std::vector<char> &dynamic_image_size); | |||
| std::vector<char> GetDynamicImageSizeChar() const; | |||
| void SetPrecisionMode(const std::vector<char> &precision_mode); | |||
| std::vector<char> GetPrecisionModeChar() const; | |||
| void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode); | |||
| std::vector<char> GetOpSelectImplModeChar() const; | |||
| void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path); | |||
| std::vector<char> GetFusionSwitchConfigPathChar() const; | |||
| void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode); | |||
| std::vector<char> GetBufferOptimizeModeChar() const; | |||
| }; | |||
| class MS_API Ascend310DeviceInfo : public DeviceInfoContext { | |||
| public: | |||
| enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; }; | |||
| void SetDeviceID(uint32_t device_id); | |||
| uint32_t GetDeviceID() const; | |||
| inline void SetDumpConfigPath(const std::string &cfg_path); | |||
| inline std::string GetDumpConfigPath() const; | |||
| // aipp config file | |||
| inline void SetInsertOpConfigPath(const std::string &cfg_path); | |||
| inline std::string GetInsertOpConfigPath() const; | |||
| // nchw or nhwc | |||
| inline void SetInputFormat(const std::string &format); | |||
| inline std::string GetInputFormat() const; | |||
| using Ascend310DeviceInfo = AscendDeviceInfo; | |||
| using Ascend910DeviceInfo = AscendDeviceInfo; | |||
| // Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1" | |||
| inline void SetInputShape(const std::string &shape); | |||
| inline std::string GetInputShape() const; | |||
| void SetInputShapeMap(const std::map<int, std::vector<int>> &shape); | |||
| std::map<int, std::vector<int>> GetInputShapeMap() const; | |||
| void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size); | |||
| inline std::string GetDynamicBatchSize() const; | |||
| // FP32, UINT8 or FP16, default as FP32 | |||
| void SetOutputType(enum DataType output_type); | |||
| enum DataType GetOutputType() const; | |||
| // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype", default as "force_fp16" | |||
| inline void SetPrecisionMode(const std::string &precision_mode); | |||
| inline std::string GetPrecisionMode() const; | |||
| // Optional "high_performance" and "high_precision", "high_performance" is set as default | |||
| inline void SetOpSelectImplMode(const std::string &op_select_impl_mode); | |||
| inline std::string GetOpSelectImplMode() const; | |||
| inline void SetFusionSwitchConfigPath(const std::string &cfg_path); | |||
| inline std::string GetFusionSwitchConfigPath() const; | |||
| // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize" | |||
| inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode); | |||
| inline std::string GetBufferOptimizeMode() const; | |||
| private: | |||
| void SetDumpConfigPath(const std::vector<char> &cfg_path); | |||
| std::vector<char> GetDumpConfigPathChar() const; | |||
| void SetInsertOpConfigPath(const std::vector<char> &cfg_path); | |||
| std::vector<char> GetInsertOpConfigPathChar() const; | |||
| void SetInputFormat(const std::vector<char> &format); | |||
| std::vector<char> GetInputFormatChar() const; | |||
| void SetInputShape(const std::vector<char> &shape); | |||
| std::vector<char> GetInputShapeChar() const; | |||
| std::vector<char> GetDynamicBatchSizeChar() const; | |||
| void SetPrecisionMode(const std::vector<char> &precision_mode); | |||
| std::vector<char> GetPrecisionModeChar() const; | |||
| void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode); | |||
| std::vector<char> GetOpSelectImplModeChar() const; | |||
| void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { | |||
| SetInsertOpConfigPath(StringToChar(cfg_path)); | |||
| } | |||
| std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } | |||
| void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path); | |||
| std::vector<char> GetFusionSwitchConfigPathChar() const; | |||
| void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } | |||
| std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } | |||
| void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode); | |||
| std::vector<char> GetBufferOptimizeModeChar() const; | |||
| }; | |||
| void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } | |||
| std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } | |||
| void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); } | |||
| std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); } | |||
| std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } | |||
| void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { | |||
| SetInsertOpConfigPath(StringToChar(cfg_path)); | |||
| void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) { | |||
| SetDynamicImageSize(StringToChar(dynamic_image_size)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } | |||
| void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } | |||
| std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } | |||
| void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } | |||
| std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } | |||
| std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } | |||
| std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); } | |||
| void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) { | |||
| SetPrecisionMode(StringToChar(precision_mode)); | |||
| void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) { | |||
| SetPrecisionMode(StringToChar(precision_mode)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } | |||
| std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } | |||
| void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) { | |||
| SetOpSelectImplMode(StringToChar(op_select_impl_mode)); | |||
| void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) { | |||
| SetOpSelectImplMode(StringToChar(op_select_impl_mode)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); } | |||
| std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); } | |||
| void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) { | |||
| SetFusionSwitchConfigPath(StringToChar(cfg_path)); | |||
| void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) { | |||
| SetFusionSwitchConfigPath(StringToChar(cfg_path)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const { | |||
| return CharToString(GetFusionSwitchConfigPathChar()); | |||
| std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const { | |||
| return CharToString(GetFusionSwitchConfigPathChar()); | |||
| } | |||
| void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { | |||
| SetBufferOptimizeMode(StringToChar(buffer_optimize_mode)); | |||
| void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { | |||
| SetBufferOptimizeMode(StringToChar(buffer_optimize_mode)); | |||
| } | |||
| std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } | |||
| std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| @@ -317,13 +317,7 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const { | |||
| return ret; | |||
| } | |||
| void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; } | |||
| uint32_t Ascend910DeviceInfo::GetDeviceID() const { | |||
| MS_LOG(ERROR) << "Unsupported Feature."; | |||
| return 0; | |||
| } | |||
| void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) { | |||
| void AscendDeviceInfo::SetDeviceID(uint32_t device_id) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| @@ -331,7 +325,7 @@ void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) { | |||
| data_->params[kModelOptionAscend310DeviceID] = device_id; | |||
| } | |||
| uint32_t Ascend310DeviceInfo::GetDeviceID() const { | |||
| uint32_t AscendDeviceInfo::GetDeviceID() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return 0; | |||
| @@ -339,14 +333,14 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const { | |||
| return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID); | |||
| } | |||
| void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) { | |||
| void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| } | |||
| data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::vector<char>(); | |||
| @@ -355,7 +349,7 @@ std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const { | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) { | |||
| void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| @@ -363,7 +357,7 @@ void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) { | |||
| data_->params[kModelOptionAscend310InputFormat] = CharToString(format); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetInputFormatChar() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::vector<char>(); | |||
| @@ -372,14 +366,14 @@ std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const { | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) { | |||
| void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| } | |||
| data_->params[kModelOptionAscend310InputShape] = CharToString(shape); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetInputShapeChar() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::vector<char>(); | |||
| @@ -388,7 +382,7 @@ std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const { | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) { | |||
| void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| @@ -403,7 +397,7 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic | |||
| data_->params[kModelOptionAscend310DynamicBatchSize] = batchs; | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::vector<char>(); | |||
| @@ -412,7 +406,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const { | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { | |||
| void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| @@ -420,7 +414,7 @@ void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_i | |||
| data_->params[kModelOptionAscend310DynamicImageSize] = CharToString(dynamic_image_size); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::vector<char>(); | |||
| @@ -429,7 +423,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const { | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) { | |||
| void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| @@ -437,7 +431,7 @@ void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mo | |||
| data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::vector<char>(); | |||
| @@ -446,7 +440,7 @@ std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const { | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) { | |||
| void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| @@ -454,7 +448,7 @@ void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select | |||
| data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::vector<char>(); | |||
| @@ -463,14 +457,14 @@ std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const { | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) { | |||
| void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| } | |||
| data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::vector<char>(); | |||
| @@ -479,7 +473,7 @@ std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const { | |||
| return StringToChar(ref); | |||
| } | |||
| void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) { | |||
| void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| @@ -487,7 +481,7 @@ void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> | |||
| data_->params[kModelOptionAscend310InputShapeMap] = shape; | |||
| } | |||
| std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const { | |||
| std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::map<int, std::vector<int>>(); | |||
| @@ -495,7 +489,7 @@ std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const { | |||
| return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap); | |||
| } | |||
| void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { | |||
| void AscendDeviceInfo::SetOutputType(enum DataType output_type) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| @@ -503,7 +497,7 @@ void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { | |||
| data_->params[kModelOptionAscend310OutputType] = output_type; | |||
| } | |||
| enum DataType Ascend310DeviceInfo::GetOutputType() const { | |||
| enum DataType AscendDeviceInfo::GetOutputType() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return DataType::kTypeUnknown; | |||
| @@ -511,7 +505,7 @@ enum DataType Ascend310DeviceInfo::GetOutputType() const { | |||
| return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType); | |||
| } | |||
| void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) { | |||
| void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return; | |||
| @@ -519,7 +513,7 @@ void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_ | |||
| data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode); | |||
| } | |||
| std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const { | |||
| std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const { | |||
| if (data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid context."; | |||
| return std::vector<char>(); | |||
| @@ -104,7 +104,7 @@ lite::InnerContext *ContextUtils::Convert(Context *context) { | |||
| } else if (device->GetDeviceType() == kKirinNPU) { | |||
| auto npu_context = device->Cast<KirinNPUDeviceInfo>(); | |||
| ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get()); | |||
| } else if (device->GetDeviceType() == kAscend310) { | |||
| } else if (device->GetDeviceType() == kAscend) { | |||
| ret = AddAscend310Device(inner_context.get(), device.get()); | |||
| } | |||
| if (ret != kSuccess) { | |||
| @@ -71,11 +71,11 @@ std::shared_ptr<mindspore::Context> Common::ContextAutoSet() { | |||
| auto context = std::make_shared<mindspore::Context>(); | |||
| if (device_target_str == "Ascend310") { | |||
| auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>(); | |||
| auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>(); | |||
| ascend310_info->SetDeviceID(device_id); | |||
| context->MutableDeviceInfo().emplace_back(ascend310_info); | |||
| } else if (device_target_str == "Ascend910") { | |||
| auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>(); | |||
| auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>(); | |||
| ascend310_info->SetDeviceID(device_id); | |||
| context->MutableDeviceInfo().emplace_back(ascend310_info); | |||
| } else { | |||
| @@ -101,7 +101,7 @@ TEST_F(TestDE, TestDvpp) { | |||
| auto context = ContextAutoSet(); | |||
| ASSERT_TRUE(context != nullptr); | |||
| ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>(); | |||
| ASSERT_TRUE(ascend310_info != nullptr); | |||
| auto device_id = ascend310_info->GetDeviceID(); | |||
| @@ -154,7 +154,7 @@ TEST_F(TestDE, TestDvppSinkMode) { | |||
| auto context = ContextAutoSet(); | |||
| ASSERT_TRUE(context != nullptr); | |||
| ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>(); | |||
| ASSERT_TRUE(ascend310_info != nullptr); | |||
| auto device_id = ascend310_info->GetDeviceID(); | |||
| @@ -202,7 +202,7 @@ TEST_F(TestDE, TestDvppDecodeResizeCropNormalize) { | |||
| auto context = ContextAutoSet(); | |||
| ASSERT_TRUE(context != nullptr); | |||
| ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>(); | |||
| ASSERT_TRUE(ascend310_info != nullptr); | |||
| auto device_id = ascend310_info->GetDeviceID(); | |||
| @@ -38,7 +38,7 @@ TEST_F(TestDynamicBatchSize, InferMindIR) { | |||
| auto context = ContextAutoSet(); | |||
| ASSERT_TRUE(context != nullptr); | |||
| ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>(); | |||
| ASSERT_TRUE(ascend310_info != nullptr); | |||
| std::map<int, std::vector<int>> input_shape; | |||
| @@ -59,7 +59,7 @@ TEST_F(TestZeroCopy, TestMindIR) { | |||
| auto context = ContextAutoSet(); | |||
| ASSERT_TRUE(context != nullptr); | |||
| ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>(); | |||
| ASSERT_TRUE(ascend310_info != nullptr); | |||
| ascend310_info->SetInsertOpConfigPath(aipp_path); | |||
| auto device_id = ascend310_info->GetDeviceID(); | |||
| @@ -107,7 +107,7 @@ TEST_F(TestZeroCopy, TestDeviceTensor) { | |||
| auto context = ContextAutoSet(); | |||
| ASSERT_TRUE(context != nullptr); | |||
| ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); | |||
| auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>(); | |||
| ASSERT_TRUE(ascend310_info != nullptr); | |||
| ascend310_info->SetInsertOpConfigPath(aipp_path); | |||
| auto device_id = ascend310_info->GetDeviceID(); | |||
| @@ -27,32 +27,27 @@ TEST_F(TestCxxApiContext, test_context_device_info_cast_SUCCESS) { | |||
| std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>(); | |||
| ASSERT_TRUE(cpu->Cast<CPUDeviceInfo>() != nullptr); | |||
| ASSERT_TRUE(gpu->Cast<GPUDeviceInfo>() != nullptr); | |||
| ASSERT_TRUE(kirin_npu->Cast<KirinNPUDeviceInfo>() != nullptr); | |||
| ASSERT_TRUE(ascend310->Cast<Ascend310DeviceInfo>() != nullptr); | |||
| ASSERT_TRUE(ascend910->Cast<Ascend910DeviceInfo>() != nullptr); | |||
| ASSERT_TRUE(ascend->Cast<AscendDeviceInfo>() != nullptr); | |||
| } | |||
| TEST_F(TestCxxApiContext, test_context_device_info_cast_FAILED) { | |||
| std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>(); | |||
| std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>(); | |||
| ASSERT_TRUE(cpu->Cast<GPUDeviceInfo>() == nullptr); | |||
| ASSERT_TRUE(kirin_npu->Cast<GPUDeviceInfo>() == nullptr); | |||
| ASSERT_TRUE(ascend310->Cast<GPUDeviceInfo>() == nullptr); | |||
| ASSERT_TRUE(ascend910->Cast<GPUDeviceInfo>() == nullptr); | |||
| ASSERT_TRUE(ascend->Cast<GPUDeviceInfo>() == nullptr); | |||
| ASSERT_TRUE(gpu->Cast<CPUDeviceInfo>() == nullptr); | |||
| ASSERT_TRUE(kirin_npu->Cast<CPUDeviceInfo>() == nullptr); | |||
| ASSERT_TRUE(ascend310->Cast<CPUDeviceInfo>() == nullptr); | |||
| ASSERT_TRUE(ascend910->Cast<CPUDeviceInfo>() == nullptr); | |||
| ASSERT_TRUE(ascend->Cast<CPUDeviceInfo>() == nullptr); | |||
| } | |||
| TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) { | |||
| @@ -86,7 +81,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) { | |||
| std::string option_9_ans = "1,2,3,4,5"; | |||
| auto context = std::make_shared<Context>(); | |||
| std::shared_ptr<Ascend310DeviceInfo> ascend310 = std::make_shared<Ascend310DeviceInfo>(); | |||
| std::shared_ptr<AscendDeviceInfo> ascend310 = std::make_shared<AscendDeviceInfo>(); | |||
| ascend310->SetInputShape(option_1); | |||
| ascend310->SetInsertOpConfigPath(option_2); | |||
| ascend310->SetOpSelectImplMode(option_3); | |||
| @@ -99,7 +94,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) { | |||
| context->MutableDeviceInfo().push_back(ascend310); | |||
| ASSERT_EQ(context->MutableDeviceInfo().size(), 1); | |||
| auto ctx = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); | |||
| auto ctx = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>(); | |||
| ASSERT_TRUE(ctx != nullptr); | |||
| ASSERT_EQ(ascend310->GetInputShape(), option_1); | |||
| ASSERT_EQ(ascend310->GetInsertOpConfigPath(), option_2); | |||
| @@ -113,7 +108,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) { | |||
| } | |||
| TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) { | |||
| auto ctx = std::make_shared<Ascend310DeviceInfo>(); | |||
| auto ctx = std::make_shared<AscendDeviceInfo>(); | |||
| ASSERT_EQ(ctx->GetOpSelectImplMode(), ""); | |||
| } | |||
| } // namespace mindspore | |||