Browse Source

add 310/910 support-check

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v1.3.0
zhoufeng 4 years ago
parent
commit
b6264c297e
7 changed files with 47 additions and 14 deletions
  1. +21
    -0
      mindspore/ccsrc/cxx_api/model/acl/acl_model.cc
  2. +2
    -0
      mindspore/ccsrc/cxx_api/model/acl/acl_model.h
  3. +7
    -0
      mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc
  4. +3
    -14
      mindspore/ccsrc/cxx_api/model/model.cc
  5. +2
    -0
      mindspore/ccsrc/cxx_api/model/model_impl.h
  6. +10
    -0
      mindspore/ccsrc/cxx_api/model/ms/ms_model.cc
  7. +2
    -0
      mindspore/ccsrc/cxx_api/model/ms/ms_model.h

+ 21
- 0
mindspore/ccsrc/cxx_api/model/acl/acl_model.cc View File

@@ -16,9 +16,12 @@

#include "cxx_api/model/acl/acl_model.h"
#include <memory>
#include <string>
#include <set>
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/graph/acl/acl_env_guard.h"
#include "acl/acl_base.h"

namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModel);
@@ -165,4 +168,22 @@ std::vector<MSTensor> AclModel::GetOutputs() {
MS_EXCEPTION_IF_NULL(graph_cell_);
return graph_cell_->GetOutputs();
}

bool AclModel::CheckModelSupport(enum ModelType model_type) {
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") != std::string::npos) {
return false;
}

static const std::set<ModelType> kSupportedModelMap = {kMindIR, kOM};
auto iter = kSupportedModelMap.find(model_type);
if (iter == kSupportedModelMap.end()) {
return false;
}
return true;
}
} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/cxx_api/model/acl/acl_model.h View File

@@ -43,6 +43,8 @@ class AclModel : public ModelImpl {
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

bool CheckModelSupport(enum ModelType model_type) override;

private:
ModelConverter model_converter_;
std::unique_ptr<AclModelOptions> options_;


+ 7
- 0
mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc View File

@@ -17,6 +17,7 @@
#include <memory>
#include "utils/log_adapter.h"
#include "external/ge/ge_api_types.h"
#include "acl/acl_base.h"

namespace mindspore {
static const std::map<enum DataType, std::string> kSupportedDtypeOptionMap = {{DataType::kNumberTypeFloat16, "FP16"},
@@ -55,6 +56,12 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
device_id_ = ascend310_info->GetDeviceID();
dump_cfg_path_ = ascend310_info->GetDumpConfigPath();
buffer_optimize_mode_ = ascend310_info->GetBufferOptimizeMode();
const char *soc_name = aclrtGetSocName();
if (soc_name == nullptr) {
MS_LOG(WARNING) << "Get soc version failed.";
return;
}
soc_version_ = soc_name;
}

void AclModelOptions::RenameInput(const std::vector<std::string> &input_names) {


+ 3
- 14
mindspore/ccsrc/cxx_api/model/model.cc View File

@@ -21,12 +21,6 @@

namespace mindspore {
namespace {
const std::map<enum DeviceType, std::set<ModelType>> kSupportedModelMap = {
{kAscend310, {kOM, kMindIR}},
{kAscend910, {kMindIR}},
{kNvidiaGPU, {kMindIR}},
};

std::string GetDeviceTypeString(enum DeviceType type) {
static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = {
{kCPU, "CPU"}, {kMaliGPU, "MaliGPU"}, {kNvidiaGPU, "GPU"},
@@ -144,16 +138,11 @@ bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type)
return false;
}

auto first_iter = kSupportedModelMap.find(device_type);
if (first_iter == kSupportedModelMap.end()) {
return false;
}

auto secend_iter = first_iter->second.find(model_type);
if (secend_iter == first_iter->second.end()) {
auto check_model = Factory<ModelImpl>::Instance().Create(device_type_str);
if (check_model == nullptr) {
return false;
}

return true;
return check_model->CheckModelSupport(model_type);
}
} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/cxx_api/model/model_impl.h View File

@@ -42,6 +42,8 @@ class ModelImpl {
virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0;

virtual bool CheckModelSupport(enum ModelType model_type) { return false; }

protected:
Status Load(const std::shared_ptr<GraphCell> &graph_cell, uint32_t device_id) {
MS_EXCEPTION_IF_NULL(graph_cell);


+ 10
- 0
mindspore/ccsrc/cxx_api/model/ms/ms_model.cc View File

@@ -16,6 +16,7 @@

#include "cxx_api/model/ms/ms_model.h"
#include <memory>
#include <set>
#include "include/api/context.h"
#include "utils/ms_context.h"
#include "cxx_api/factory.h"
@@ -169,4 +170,13 @@ uint32_t MsModel::GetDeviceID() const {

return 0;
}

bool MsModel::CheckModelSupport(enum ModelType model_type) {
static const std::set<ModelType> kSupportedModelMap = {kMindIR};
auto iter = kSupportedModelMap.find(model_type);
if (iter == kSupportedModelMap.end()) {
return false;
}
return true;
}
} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/cxx_api/model/ms/ms_model.h View File

@@ -44,6 +44,8 @@ class MsModel : public ModelImpl {
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

bool CheckModelSupport(enum ModelType model_type) override;

private:
std::shared_ptr<GraphCell> GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims);
uint32_t GetDeviceID() const;


Loading…
Cancel
Save