|
|
|
@@ -20,6 +20,13 @@ |
|
|
|
#include "utils/utils.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace { |
|
|
|
const std::map<std::string, std::set<ModelType>> kSupportedModelMap = { |
|
|
|
{kDeviceTypeAscend310, {kOM, kMindIR}}, |
|
|
|
{kDeviceTypeAscend910, {kMindIR}}, |
|
|
|
{kDeviceTypeGPU, {kMindIR}}, |
|
|
|
}; |
|
|
|
} |
|
|
|
Status Model::Build() { |
|
|
|
MS_EXCEPTION_IF_NULL(impl_); |
|
|
|
return impl_->Build(); |
|
|
|
@@ -61,8 +68,21 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context> |
|
|
|
|
|
|
|
Model::~Model() {} |
|
|
|
|
|
|
|
bool Model::CheckModelSupport(const std::string &device_type, ModelType) { |
|
|
|
return Factory<ModelImpl>::Instance().CheckModelSupport(device_type); |
|
|
|
} |
|
|
|
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { |
|
|
|
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto first_iter = kSupportedModelMap.find(device_type); |
|
|
|
if (first_iter == kSupportedModelMap.end()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto secend_iter = first_iter->second.find(model_type); |
|
|
|
if (secend_iter == first_iter->second.end()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace mindspore |