/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "common/servable.h" #include #include #include "worker/preprocess.h" #include "worker/postprocess.h" namespace mindspore::serving { std::string ServableMeta::Repr() const { std::ostringstream stream; stream << "path(" << servable_name << ") file(" << servable_file + ")"; return stream.str(); } void ServableMeta::SetModelFormat(const std::string &format) { if (format == "om") { model_format = kOM; } else if (format == "mindir") { model_format = kMindIR; } else { MSI_LOG_ERROR << "Invalid model format " << format; } } std::string LoadServableSpec::Repr() const { std::string version; if (version_number > 0) { version = " version(" + std::to_string(version_number) + ") "; } return "servable(" + servable_name + ") " + version; } std::string WorkerSpec::Repr() const { std::string version; if (version_number > 0) { version = " version(" + std::to_string(version_number) + ") "; } return "servable(" + servable_name + ") " + version + " address(" + worker_address + ") "; } std::string RequestSpec::Repr() const { std::string version; if (version_number > 0) { version = " version(" + std::to_string(version_number) + ") "; } return "servable(" + servable_name + ") " + "method(" + method_name + ") " + version; } Status ServableSignature::Check() const { std::set method_set; std::string model_str = servable_meta.Repr(); for (auto &method : methods) { if (method_set.count(method.method_name) > 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " " << method.method_name << " has been defined repeatly"; } method_set.emplace(method.method_name); size_t preprocess_outputs_count = 0; size_t postprocess_outputs_count = 0; const auto &preprocess_name = method.preprocess_name; if (!preprocess_name.empty()) { auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name); if (preprocess == nullptr) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << " preprocess " << preprocess_name << " not defined"; } preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name); for (size_t i = 0; i < method.preprocess_inputs.size(); i++) { auto &input = method.preprocess_inputs[i]; if (input.first != kPredictPhaseTag_Input) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i << "th input phase tag " << input.first << " invalid"; } if (input.second >= method.inputs.size()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i << "th input uses method " << input.second << "th input, that is greater than the method inputs size " << method.inputs.size(); } } } for (size_t i = 0; i < method.servable_inputs.size(); i++) { auto &input = method.servable_inputs[i]; if (input.first == kPredictPhaseTag_Input) { if (input.second >= method.inputs.size()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the servable " << i << "th input uses method " << input.second << "th input, that is greater than the method inputs size " << method.inputs.size(); } } else if (input.first == kPredictPhaseTag_Preproces) { if (input.second >= preprocess_outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the servable " << i << "th input uses preprocess " << input.second << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; } } else { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the servable " << i << "th input phase tag " << input.first << " invalid"; } } const auto &postprocess_name = method.postprocess_name; if (!method.postprocess_name.empty()) { auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name); if (postprocess == nullptr) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << " postprocess " << postprocess_name << " not defined"; } postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name); for (size_t i = 0; i < method.postprocess_inputs.size(); i++) { auto &input = method.postprocess_inputs[i]; if (input.first == kPredictPhaseTag_Input) { if (input.second >= method.inputs.size()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i << "th input uses method " << input.second << "th input, that is greater than the method inputs size " << method.inputs.size(); } } else if (input.first == kPredictPhaseTag_Preproces) { if (input.second >= preprocess_outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i << "th input uses preprocess " << input.second << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; } } else if (input.first == kPredictPhaseTag_Predict) { if (input.second >= servable_meta.outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i << "th input uses servable " << input.second << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count; } } else { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i << "th input phase tag " << input.first << " invalid"; } } } for (size_t i = 0; i < method.returns.size(); i++) { auto &input = method.returns[i]; if (input.first == kPredictPhaseTag_Input) { if (input.second >= method.inputs.size()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the method " << i << "th output uses method " << input.second << "th input, that is greater than the method inputs size " << method.inputs.size(); } } else if (input.first == kPredictPhaseTag_Preproces) { if (input.second >= preprocess_outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the method " << i << "th output uses preprocess " << input.second << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; } } else if (input.first == kPredictPhaseTag_Predict) { if (input.second >= servable_meta.outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the method " << i << "th output uses servable " << input.second << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count; } } else if (input.first == kPredictPhaseTag_Postprocess) { if (input.second >= postprocess_outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the method " << i << "th output uses postprocess " << input.second << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count; } } else { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name << ", the method " << i << "th output phase tag " << input.first << " invalid"; } } } return SUCCESS; } bool ServableSignature::GetMethodDeclare(const std::string &method_name, MethodSignature *method) { MSI_EXCEPTION_IF_NULL(method); auto item = find_if(methods.begin(), methods.end(), [&](const MethodSignature &v) { return v.method_name == method_name; }); if (item != methods.end()) { *method = *item; return true; } return false; } void ServableStorage::Register(const ServableSignature &def) { auto model_name = def.servable_meta.servable_name; if (servable_signatures_map_.find(model_name) == servable_signatures_map_.end()) { MSI_LOG_WARNING << "Servable " << model_name << " has already been defined"; } servable_signatures_map_[model_name] = def; } bool ServableStorage::GetServableDef(const std::string &model_name, ServableSignature *def) const { MSI_EXCEPTION_IF_NULL(def); auto it = servable_signatures_map_.find(model_name); if (it == servable_signatures_map_.end()) { return false; } *def = it->second; return true; } std::shared_ptr ServableStorage::Instance() { static std::shared_ptr storage; if (storage == nullptr) { storage = std::make_shared(); } return storage; } void ServableStorage::RegisterMethod(const MethodSignature &method) { MSI_LOG_INFO << "Declare method " << method.method_name << ", servable " << method.servable_name; auto it = servable_signatures_map_.find(method.servable_name); if (it == servable_signatures_map_.end()) { ServableSignature signature; signature.methods.push_back(method); servable_signatures_map_[method.servable_name] = signature; return; } it->second.methods.push_back(method); } void ServableStorage::DeclareServable(const mindspore::serving::ServableMeta &servable) { MSI_LOG_INFO << "Declare servable " << servable.servable_name; auto it = servable_signatures_map_.find(servable.servable_name); if (it == servable_signatures_map_.end()) { ServableSignature signature; signature.servable_meta = servable; servable_signatures_map_[servable.servable_name] = signature; return; } it->second.servable_meta = servable; } void ServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count) { auto it = servable_signatures_map_.find(servable_name); if (it == servable_signatures_map_.end()) { MSI_LOG_EXCEPTION << "RegisterInputOutputInfo failed, cannot find servable " << servable_name; } auto &servable_meta = it->second.servable_meta; if (servable_meta.inputs_count != 0 && servable_meta.inputs_count != inputs_count) { MSI_LOG_EXCEPTION << "RegisterInputOutputInfo failed, inputs count " << inputs_count << " not match old count " << servable_meta.inputs_count << ",servable name " << servable_name; } if (servable_meta.outputs_count != 0 && servable_meta.outputs_count != outputs_count) { MSI_LOG_EXCEPTION << "RegisterInputOutputInfo failed, outputs count " << outputs_count << " not match old count " << servable_meta.outputs_count << ",servable name " << servable_name; } servable_meta.inputs_count = inputs_count; servable_meta.outputs_count = outputs_count; } std::vector ServableStorage::GetInputOutputInfo(const std::string &servable_name) const { std::vector result; auto it = servable_signatures_map_.find(servable_name); if (it == servable_signatures_map_.end()) { return result; } result.push_back(it->second.servable_meta.inputs_count); result.push_back(it->second.servable_meta.outputs_count); return result; } } // namespace mindspore::serving