Browse Source

check modeltype

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v1.2.0-rc1
zhoufeng 4 years ago
parent
commit
539d88552a
2 changed files with 24 additions and 3 deletions
  1. +1
    -0
      include/api/context.h
  2. +23
    -3
      mindspore/ccsrc/cxx_api/model/model.cc

+ 1
- 0
include/api/context.h View File

@@ -25,6 +25,7 @@
namespace mindspore {
constexpr auto kDeviceTypeAscend310 = "Ascend310";
constexpr auto kDeviceTypeAscend910 = "Ascend910";
constexpr auto kDeviceTypeGPU = "GPU";

struct MS_API Context {
virtual ~Context() = default;


+ 23
- 3
mindspore/ccsrc/cxx_api/model/model.cc View File

@@ -20,6 +20,13 @@
#include "utils/utils.h"

namespace mindspore {
namespace {
const std::map<std::string, std::set<ModelType>> kSupportedModelMap = {
{kDeviceTypeAscend310, {kOM, kMindIR}},
{kDeviceTypeAscend910, {kMindIR}},
{kDeviceTypeGPU, {kMindIR}},
};
}
Status Model::Build() {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Build();
@@ -61,8 +68,21 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>

Model::~Model() {}

bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
return Factory<ModelImpl>::Instance().CheckModelSupport(device_type);
}
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type)) {
return false;
}

auto first_iter = kSupportedModelMap.find(device_type);
if (first_iter == kSupportedModelMap.end()) {
return false;
}

auto secend_iter = first_iter->second.find(model_type);
if (secend_iter == first_iter->second.end()) {
return false;
}

return true;
}
} // namespace mindspore

Loading…
Cancel
Save