You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

servable.cc 14 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "common/servable.h"
  #include <algorithm>
  #include <map>
  #include <set>
  #include <sstream>
  #include "worker/preprocess.h"
  #include "worker/postprocess.h"
  22. namespace mindspore::serving {
  23. std::string ServableMeta::Repr() const {
  24. std::ostringstream stream;
  25. stream << "path(" << servable_name << ") file(" << servable_file + ")";
  26. return stream.str();
  27. }
  28. void ServableMeta::SetModelFormat(const std::string &format) {
  29. if (format == "om") {
  30. model_format = kOM;
  31. } else if (format == "mindir") {
  32. model_format = kMindIR;
  33. } else {
  34. MSI_LOG_ERROR << "Invalid model format " << format;
  35. }
  36. }
  37. std::string LoadServableSpec::Repr() const {
  38. std::string version;
  39. if (version_number > 0) {
  40. version = " version(" + std::to_string(version_number) + ") ";
  41. }
  42. return "servable(" + servable_name + ") " + version;
  43. }
  44. std::string WorkerSpec::Repr() const {
  45. std::string version;
  46. if (version_number > 0) {
  47. version = " version(" + std::to_string(version_number) + ") ";
  48. }
  49. return "servable(" + servable_name + ") " + version + " address(" + worker_address + ") ";
  50. }
  51. std::string RequestSpec::Repr() const {
  52. std::string version;
  53. if (version_number > 0) {
  54. version = " version(" + std::to_string(version_number) + ") ";
  55. }
  56. return "servable(" + servable_name + ") " + "method(" + method_name + ") " + version;
  57. }
  58. Status ServableSignature::Check() const {
  59. std::set<std::string> method_set;
  60. std::string model_str = servable_meta.Repr();
  61. for (auto &method : methods) {
  62. if (method_set.count(method.method_name) > 0) {
  63. return INFER_STATUS_LOG_ERROR(FAILED)
  64. << "Model " << model_str << " " << method.method_name << " has been defined repeatly";
  65. }
  66. method_set.emplace(method.method_name);
  67. size_t preprocess_outputs_count = 0;
  68. size_t postprocess_outputs_count = 0;
  69. const auto &preprocess_name = method.preprocess_name;
  70. if (!preprocess_name.empty()) {
  71. auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name);
  72. if (preprocess == nullptr) {
  73. return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
  74. << " preprocess " << preprocess_name << " not defined";
  75. }
  76. preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name);
  77. for (size_t i = 0; i < method.preprocess_inputs.size(); i++) {
  78. auto &input = method.preprocess_inputs[i];
  79. if (input.first != kPredictPhaseTag_Input) {
  80. return INFER_STATUS_LOG_ERROR(FAILED)
  81. << "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i
  82. << "th input cannot not come from '" << input.first << "'";
  83. }
  84. if (input.second >= method.inputs.size()) {
  85. return INFER_STATUS_LOG_ERROR(FAILED)
  86. << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i
  87. << "th input uses method " << input.second << "th input, that is greater than the method inputs size "
  88. << method.inputs.size();
  89. }
  90. }
  91. }
  92. for (size_t i = 0; i < method.servable_inputs.size(); i++) {
  93. auto &input = method.servable_inputs[i];
  94. if (input.first == kPredictPhaseTag_Input) {
  95. if (input.second >= method.inputs.size()) {
  96. return INFER_STATUS_LOG_ERROR(FAILED)
  97. << "Model " << model_str << " method " << method.method_name << ", the servable " << i
  98. << "th input uses method " << input.second << "th input, that is greater than the method inputs size "
  99. << method.inputs.size();
  100. }
  101. } else if (input.first == kPredictPhaseTag_Preproces) {
  102. if (input.second >= preprocess_outputs_count) {
  103. return INFER_STATUS_LOG_ERROR(FAILED)
  104. << "Model " << model_str << " method " << method.method_name << ", the servable " << i
  105. << "th input uses preprocess " << input.second
  106. << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
  107. }
  108. } else {
  109. return INFER_STATUS_LOG_ERROR(FAILED)
  110. << "Model " << model_str << " method " << method.method_name << ", the data of servable " << i
  111. << "th input cannot not come from '" << input.first << "'";
  112. }
  113. }
  114. const auto &postprocess_name = method.postprocess_name;
  115. if (!method.postprocess_name.empty()) {
  116. auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name);
  117. if (postprocess == nullptr) {
  118. return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
  119. << " postprocess " << postprocess_name << " not defined";
  120. }
  121. postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name);
  122. for (size_t i = 0; i < method.postprocess_inputs.size(); i++) {
  123. auto &input = method.postprocess_inputs[i];
  124. if (input.first == kPredictPhaseTag_Input) {
  125. if (input.second >= method.inputs.size()) {
  126. return INFER_STATUS_LOG_ERROR(FAILED)
  127. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  128. << "th input uses method " << input.second
  129. << "th input, that is greater than the method inputs size " << method.inputs.size();
  130. }
  131. } else if (input.first == kPredictPhaseTag_Preproces) {
  132. if (input.second >= preprocess_outputs_count) {
  133. return INFER_STATUS_LOG_ERROR(FAILED)
  134. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  135. << "th input uses preprocess " << input.second
  136. << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
  137. }
  138. } else if (input.first == kPredictPhaseTag_Predict) {
  139. if (input.second >= servable_meta.outputs_count) {
  140. return INFER_STATUS_LOG_ERROR(FAILED)
  141. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  142. << "th input uses servable " << input.second
  143. << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count;
  144. }
  145. } else {
  146. return INFER_STATUS_LOG_ERROR(FAILED)
  147. << "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i
  148. << "th input cannot not come from '" << input.first << "'";
  149. }
  150. }
  151. }
  152. for (size_t i = 0; i < method.returns.size(); i++) {
  153. auto &input = method.returns[i];
  154. if (input.first == kPredictPhaseTag_Input) {
  155. if (input.second >= method.inputs.size()) {
  156. return INFER_STATUS_LOG_ERROR(FAILED)
  157. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  158. << "th output uses method " << input.second << "th input, that is greater than the method inputs size "
  159. << method.inputs.size();
  160. }
  161. } else if (input.first == kPredictPhaseTag_Preproces) {
  162. if (input.second >= preprocess_outputs_count) {
  163. return INFER_STATUS_LOG_ERROR(FAILED)
  164. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  165. << "th output uses preprocess " << input.second
  166. << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
  167. }
  168. } else if (input.first == kPredictPhaseTag_Predict) {
  169. if (input.second >= servable_meta.outputs_count) {
  170. return INFER_STATUS_LOG_ERROR(FAILED)
  171. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  172. << "th output uses servable " << input.second
  173. << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count;
  174. }
  175. } else if (input.first == kPredictPhaseTag_Postprocess) {
  176. if (input.second >= postprocess_outputs_count) {
  177. return INFER_STATUS_LOG_ERROR(FAILED)
  178. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  179. << "th output uses postprocess " << input.second
  180. << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count;
  181. }
  182. } else {
  183. return INFER_STATUS_LOG_ERROR(FAILED)
  184. << "Model " << model_str << " method " << method.method_name << ", the data of method " << i
  185. << "th output cannot not come from '" << input.first << "'";
  186. }
  187. }
  188. }
  189. return SUCCESS;
  190. }
  191. bool ServableSignature::GetMethodDeclare(const std::string &method_name, MethodSignature *method) {
  192. MSI_EXCEPTION_IF_NULL(method);
  193. auto item =
  194. find_if(methods.begin(), methods.end(), [&](const MethodSignature &v) { return v.method_name == method_name; });
  195. if (item != methods.end()) {
  196. *method = *item;
  197. return true;
  198. }
  199. return false;
  200. }
  201. void ServableStorage::Register(const ServableSignature &def) {
  202. auto model_name = def.servable_meta.servable_name;
  203. if (servable_signatures_map_.find(model_name) == servable_signatures_map_.end()) {
  204. MSI_LOG_WARNING << "Servable " << model_name << " has already been defined";
  205. }
  206. servable_signatures_map_[model_name] = def;
  207. }
  208. bool ServableStorage::GetServableDef(const std::string &model_name, ServableSignature *def) const {
  209. MSI_EXCEPTION_IF_NULL(def);
  210. auto it = servable_signatures_map_.find(model_name);
  211. if (it == servable_signatures_map_.end()) {
  212. return false;
  213. }
  214. *def = it->second;
  215. return true;
  216. }
  217. ServableStorage &ServableStorage::Instance() {
  218. static ServableStorage storage;
  219. return storage;
  220. }
  221. Status ServableStorage::RegisterMethod(const MethodSignature &method) {
  222. MSI_LOG_INFO << "Declare method " << method.method_name << ", servable " << method.servable_name;
  223. auto it = servable_signatures_map_.find(method.servable_name);
  224. if (it == servable_signatures_map_.end()) {
  225. ServableSignature signature;
  226. signature.methods.push_back(method);
  227. servable_signatures_map_[method.servable_name] = signature;
  228. return SUCCESS;
  229. }
  230. for (auto &item : it->second.methods) {
  231. // cppcheck-suppress useStlAlgorithm
  232. if (item.method_name == method.method_name) {
  233. return INFER_STATUS_LOG_ERROR(FAILED)
  234. << "Method " << method.method_name << " has been registered more than once.";
  235. }
  236. }
  237. it->second.methods.push_back(method);
  238. return SUCCESS;
  239. }
  240. void ServableStorage::DeclareServable(const mindspore::serving::ServableMeta &servable) {
  241. MSI_LOG_INFO << "Declare servable " << servable.servable_name;
  242. auto it = servable_signatures_map_.find(servable.servable_name);
  243. if (it == servable_signatures_map_.end()) {
  244. ServableSignature signature;
  245. signature.servable_meta = servable;
  246. servable_signatures_map_[servable.servable_name] = signature;
  247. return;
  248. }
  249. it->second.servable_meta = servable;
  250. }
  251. Status ServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count,
  252. size_t outputs_count) {
  253. auto it = servable_signatures_map_.find(servable_name);
  254. if (it == servable_signatures_map_.end()) {
  255. return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, cannot find servable " << servable_name;
  256. }
  257. auto &servable_meta = it->second.servable_meta;
  258. if (servable_meta.inputs_count != 0 && servable_meta.inputs_count != inputs_count) {
  259. return INFER_STATUS_LOG_ERROR(FAILED)
  260. << "RegisterInputOutputInfo failed, inputs count " << inputs_count << " not match old count "
  261. << servable_meta.inputs_count << ",servable name " << servable_name;
  262. }
  263. if (servable_meta.outputs_count != 0 && servable_meta.outputs_count != outputs_count) {
  264. return INFER_STATUS_LOG_ERROR(FAILED)
  265. << "RegisterInputOutputInfo failed, outputs count " << outputs_count << " not match old count "
  266. << servable_meta.outputs_count << ",servable name " << servable_name;
  267. }
  268. servable_meta.inputs_count = inputs_count;
  269. servable_meta.outputs_count = outputs_count;
  270. return SUCCESS;
  271. }
  272. std::vector<size_t> ServableStorage::GetInputOutputInfo(const std::string &servable_name) const {
  273. std::vector<size_t> result;
  274. auto it = servable_signatures_map_.find(servable_name);
  275. if (it == servable_signatures_map_.end()) {
  276. return result;
  277. }
  278. result.push_back(it->second.servable_meta.inputs_count);
  279. result.push_back(it->second.servable_meta.outputs_count);
  280. return result;
  281. }
  282. void ServableStorage::Clear() { servable_signatures_map_.clear(); }
  283. } // namespace mindspore::serving

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.