diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc index 3c782576af..1e9392a17b 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc @@ -16,9 +16,12 @@ #include "cxx_api/model/acl/acl_model.h" #include +#include +#include #include "include/api/context.h" #include "cxx_api/factory.h" #include "cxx_api/graph/acl/acl_env_guard.h" +#include "acl/acl_base.h" namespace mindspore { API_FACTORY_REG(ModelImpl, Ascend310, AclModel); @@ -165,4 +168,22 @@ std::vector AclModel::GetOutputs() { MS_EXCEPTION_IF_NULL(graph_cell_); return graph_cell_->GetOutputs(); } + +bool AclModel::CheckModelSupport(enum ModelType model_type) { + const char *soc_name_c = aclrtGetSocName(); + if (soc_name_c == nullptr) { + return false; + } + std::string soc_name(soc_name_c); + if (soc_name.find("910") != std::string::npos) { + return false; + } + + static const std::set kSupportedModelMap = {kMindIR, kOM}; + auto iter = kSupportedModelMap.find(model_type); + if (iter == kSupportedModelMap.end()) { + return false; + } + return true; +} } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.h b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h index be538e18c3..5520299216 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model.h +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h @@ -43,6 +43,8 @@ class AclModel : public ModelImpl { std::vector GetInputs() override; std::vector GetOutputs() override; + bool CheckModelSupport(enum ModelType model_type) override; + private: ModelConverter model_converter_; std::unique_ptr options_; diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc index 82dec9abaf..c83ca6af47 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc @@ -17,6 +17,7 @@ #include #include "utils/log_adapter.h" #include "external/ge/ge_api_types.h" +#include "acl/acl_base.h" namespace mindspore { static const std::map kSupportedDtypeOptionMap = {{DataType::kNumberTypeFloat16, "FP16"}, @@ -55,6 +56,12 @@ AclModelOptions::AclModelOptions(const std::shared_ptr &context) { device_id_ = ascend310_info->GetDeviceID(); dump_cfg_path_ = ascend310_info->GetDumpConfigPath(); buffer_optimize_mode_ = ascend310_info->GetBufferOptimizeMode(); + const char *soc_name = aclrtGetSocName(); + if (soc_name == nullptr) { + MS_LOG(WARNING) << "Get soc version failed."; + return; + } + soc_version_ = soc_name; } void AclModelOptions::RenameInput(const std::vector &input_names) { diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index be81447ac1..88f1913688 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -21,12 +21,6 @@ namespace mindspore { namespace { -const std::map> kSupportedModelMap = { - {kAscend310, {kOM, kMindIR}}, - {kAscend910, {kMindIR}}, - {kNvidiaGPU, {kMindIR}}, -}; - std::string GetDeviceTypeString(enum DeviceType type) { static const std::map kDeviceTypeStrs = { {kCPU, "CPU"}, {kMaliGPU, "MaliGPU"}, {kNvidiaGPU, "GPU"}, @@ -144,16 +138,11 @@ bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) return false; } - auto first_iter = kSupportedModelMap.find(device_type); - if (first_iter == kSupportedModelMap.end()) { - return false; - } - - auto secend_iter = first_iter->second.find(model_type); - if (secend_iter == first_iter->second.end()) { + auto check_model = Factory::Instance().Create(device_type_str); + if (check_model == nullptr) { return false; } - return true; + return check_model->CheckModelSupport(model_type); } } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/model/model_impl.h b/mindspore/ccsrc/cxx_api/model/model_impl.h index c0456fb13a..8dde62d7f5 100644 --- a/mindspore/ccsrc/cxx_api/model/model_impl.h +++ b/mindspore/ccsrc/cxx_api/model/model_impl.h @@ -42,6 +42,8 @@ class ModelImpl { virtual std::vector GetInputs() = 0; virtual std::vector GetOutputs() = 0; + virtual bool CheckModelSupport(enum ModelType model_type) { return false; } + protected: Status Load(const std::shared_ptr &graph_cell, uint32_t device_id) { MS_EXCEPTION_IF_NULL(graph_cell); diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc index 51a3cb0756..8328ff6eea 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc @@ -16,6 +16,7 @@ #include "cxx_api/model/ms/ms_model.h" #include +#include #include "include/api/context.h" #include "utils/ms_context.h" #include "cxx_api/factory.h" @@ -169,4 +170,13 @@ uint32_t MsModel::GetDeviceID() const { return 0; } + +bool MsModel::CheckModelSupport(enum ModelType model_type) { + static const std::set kSupportedModelMap = {kMindIR}; + auto iter = kSupportedModelMap.find(model_type); + if (iter == kSupportedModelMap.end()) { + return false; + } + return true; +} } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.h b/mindspore/ccsrc/cxx_api/model/ms/ms_model.h index aec854c624..78cefd4eae 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.h +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.h @@ -44,6 +44,8 @@ class MsModel : public ModelImpl { std::vector GetInputs() override; std::vector GetOutputs() override; + bool CheckModelSupport(enum ModelType model_type) override; + private: std::shared_ptr GenerateGraphCell(const std::vector> &dims); uint32_t GetDeviceID() const;