/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "include/api/model.h"
#include "include/api/context.h"
#include "cxx_api/model/model_impl.h"
#include "cxx_api/factory.h"
#include "utils/utils.h"

namespace mindspore {
namespace {
// Model formats accepted by each backend device.
// NOTE(review): all template arguments in this file were reconstructed from
// surrounding usage -- the extracted text had lost its angle-bracket
// contents. Confirm each one against the repository before merging.
const std::map<enum DeviceType, std::set<ModelType>> kSupportedModelMap = {
  {kAscend310, {kOM, kMindIR}},
  {kAscend910, {kMindIR}},
  {kNvidiaGPU, {kMindIR}},
};

// Maps a DeviceType enum to the string key used to look up the backend
// ModelImpl in the factory. Unknown values produce a distinctive
// "InvalidDeviceType<n>" string so the later factory lookup fails with a
// readable session-type name in the error log.
std::string GetDeviceTypeString(enum DeviceType type) {
  // NOTE(review): kKirinNPU maps to "KirinGPU" -- looks like a typo for
  // "KirinNPU", but this string must match the factory registration key,
  // so it is deliberately left unchanged here.
  static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = {
    {kCPU, "CPU"},           {kMaliGPU, "MaliGPU"},     {kNvidiaGPU, "GPU"},
    {kKirinNPU, "KirinGPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"},
  };
  auto iter = kDeviceTypeStrs.find(type);
  if (iter != kDeviceTypeStrs.end()) {
    return iter->second;
  }
  return "InvalidDeviceType" + std::to_string(type);
}
}  // namespace

// Builds the model: validates the graph and context, resolves the backend
// from the single configured device info, creates the backend-specific
// ModelImpl through the factory, and delegates graph compilation to it.
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context) {
  if (graph_cell.GetGraph() == nullptr) {
    MS_LOG(ERROR) << "Invalid graph input.";
    return kMCInvalidInput;
  }
  if (model_context == nullptr) {
    MS_LOG(ERROR) << "Invalid model context.";
    return kMCInvalidInput;
  }
  auto &device_info = model_context->MutableDeviceInfo();
  // Exactly one device info is supported; heterogeneous configs are rejected.
  if (device_info.size() != 1) {
    MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
    return kMCInvalidInput;
  }
  std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType());
  impl_ = Factory<ModelImpl>::Instance().Create(device_target);
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Create session type " << device_target << " failed";
    // NOTE(review): returns a kME* code while the rest of this file uses
    // kMC* codes -- kept as-is in case callers match on this exact value.
    return kMEFailed;
  }
  g_device_target = device_target;
  impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
  impl_->SetContext(model_context);
  return impl_->Build();
}

// Resizes the model inputs to the given shapes. Requires a prior Build().
Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return kMCFailed;
  }
  return impl_->Resize(inputs, dims);
}

// Runs inference on the built model. Requires a prior Build().
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return kMCFailed;
  }
  return impl_->Predict(inputs, outputs);
}

// Returns the model's input tensors, or an empty vector before Build().
std::vector<MSTensor> Model::GetInputs() {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return {};
  }
  return impl_->GetInputs();
}

// Returns the model's output tensors, or an empty vector before Build().
std::vector<MSTensor> Model::GetOutputs() {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return {};
  }
  return impl_->GetOutputs();
}

// Looks up an input tensor by name; returns a null MSTensor sentinel when no
// input with that name exists (or the model is not built).
MSTensor Model::GetInputByTensorName(const std::vector<char> &tensor_name) {
  std::string tensor_name_str = CharToString(tensor_name);
  auto inputs = GetInputs();
  // const ref: avoid copying each MSTensor per iteration.
  for (const auto &in : inputs) {
    if (in.Name() == tensor_name_str) {
      return in;
    }
  }
  return MSTensor(nullptr);
}

// Returns the names of all output tensors, encoded as char vectors for the
// ABI-stable char-based public API.
std::vector<std::vector<char>> Model::GetOutputTensorNamesChar() {
  std::vector<std::vector<char>> ret;
  auto outputs = GetOutputs();
  std::transform(outputs.begin(), outputs.end(), std::back_inserter(ret),
                 // const ref: the original copied each MSTensor into the lambda.
                 [](const MSTensor &item) -> std::vector<char> { return StringToChar(item.Name()); });
  return ret;
}

// Looks up an output tensor by name; returns a null MSTensor sentinel when no
// output with that name exists (or the model is not built).
MSTensor Model::GetOutputByTensorName(const std::vector<char> &tensor_name) {
  std::string tensor_name_str = CharToString(tensor_name);
  auto outputs = GetOutputs();
  // const ref: avoid copying each MSTensor per iteration.
  for (const auto &out : outputs) {
    if (out.Name() == tensor_name_str) {
      return out;
    }
  }
  return MSTensor(nullptr);
}

Model::Model() : impl_(nullptr) {}
Model::~Model() {}

// Reports whether the given device has a registered backend AND accepts the
// given model format (per kSupportedModelMap).
bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
  std::string device_type_str = GetDeviceTypeString(device_type);
  if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
    return false;
  }
  auto first_iter = kSupportedModelMap.find(device_type);
  if (first_iter == kSupportedModelMap.end()) {
    return false;
  }
  // (fixes the 'secend_iter' naming typo by returning the lookup directly)
  return first_iter->second.find(model_type) != first_iter->second.end();
}
}  // namespace mindspore