You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

servable.cc 18 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "common/servable.h"
  #include <algorithm>
  #include <map>
  #include <set>
  #include <sstream>
  #include "worker/preprocess.h"
  #include "worker/postprocess.h"
  22. namespace mindspore::serving {
  23. std::string ServableMeta::Repr() const {
  24. std::ostringstream stream;
  25. switch (servable_type) {
  26. case kServableTypeUnknown:
  27. stream << "undeclared servable, servable name: '" << common_meta.servable_name << "'";
  28. break;
  29. case kServableTypeLocal:
  30. stream << "local servable, servable name: '" << common_meta.servable_name << "', file: '"
  31. << local_meta.servable_file + "'";
  32. break;
  33. case kServableTypeDistributed:
  34. stream << "distributed servable, servable name: '" << common_meta.servable_name
  35. << "', rank size: " << distributed_meta.rank_size << ", stage size " << distributed_meta.stage_size;
  36. break;
  37. }
  38. return stream.str();
  39. }
  40. void LocalServableMeta::SetModelFormat(const std::string &format) {
  41. if (format == "om") {
  42. model_format = kOM;
  43. } else if (format == "mindir") {
  44. model_format = kMindIR;
  45. } else {
  46. MSI_LOG_ERROR << "Invalid model format " << format;
  47. }
  48. }
  49. std::string LoadServableSpec::Repr() const {
  50. std::string version;
  51. if (version_number > 0) {
  52. version = " version(" + std::to_string(version_number) + ") ";
  53. }
  54. return "servable(" + servable_name + ") " + version;
  55. }
  56. std::string WorkerSpec::Repr() const {
  57. std::string version;
  58. if (version_number > 0) {
  59. version = " version(" + std::to_string(version_number) + ") ";
  60. }
  61. return "servable(" + servable_name + ") " + version + " address(" + worker_address + ") ";
  62. }
  63. std::string RequestSpec::Repr() const {
  64. std::string version;
  65. if (version_number > 0) {
  66. version = " version(" + std::to_string(version_number) + ") ";
  67. }
  68. return "servable(" + servable_name + ") " + "method(" + method_name + ") " + version;
  69. }
  70. Status ServableSignature::CheckPreprocessInput(const MethodSignature &method, size_t *preprocess_outputs_count) const {
  71. std::string model_str = servable_meta.Repr();
  72. const auto &preprocess_name = method.preprocess_name;
  73. if (!preprocess_name.empty()) {
  74. auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name);
  75. if (preprocess == nullptr) {
  76. return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
  77. << " preprocess " << preprocess_name << " not defined";
  78. }
  79. *preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name);
  80. for (size_t i = 0; i < method.preprocess_inputs.size(); i++) {
  81. auto &input = method.preprocess_inputs[i];
  82. if (input.first != kPredictPhaseTag_Input) {
  83. return INFER_STATUS_LOG_ERROR(FAILED)
  84. << "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i
  85. << "th input cannot not come from '" << input.first << "'";
  86. }
  87. if (input.second >= method.inputs.size()) {
  88. return INFER_STATUS_LOG_ERROR(FAILED)
  89. << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i
  90. << "th input uses method " << input.second << "th input, that is greater than the method inputs size "
  91. << method.inputs.size();
  92. }
  93. }
  94. }
  95. return SUCCESS;
  96. }
  97. Status ServableSignature::CheckPredictInput(const MethodSignature &method, size_t preprocess_outputs_count) const {
  98. std::string model_str = servable_meta.Repr();
  99. for (size_t i = 0; i < method.servable_inputs.size(); i++) {
  100. auto &input = method.servable_inputs[i];
  101. if (input.first == kPredictPhaseTag_Input) {
  102. if (input.second >= method.inputs.size()) {
  103. return INFER_STATUS_LOG_ERROR(FAILED)
  104. << "Model " << model_str << " method " << method.method_name << ", the servable " << i
  105. << "th input uses method " << input.second << "th input, that is greater than the method inputs size "
  106. << method.inputs.size();
  107. }
  108. } else if (input.first == kPredictPhaseTag_Preproces) {
  109. if (input.second >= preprocess_outputs_count) {
  110. return INFER_STATUS_LOG_ERROR(FAILED)
  111. << "Model " << model_str << " method " << method.method_name << ", the servable " << i
  112. << "th input uses preprocess " << input.second
  113. << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
  114. }
  115. } else {
  116. return INFER_STATUS_LOG_ERROR(FAILED)
  117. << "Model " << model_str << " method " << method.method_name << ", the data of servable " << i
  118. << "th input cannot not come from '" << input.first << "'";
  119. }
  120. }
  121. return SUCCESS;
  122. }
  123. Status ServableSignature::CheckPostprocessInput(const MethodSignature &method, size_t preprocess_outputs_count,
  124. size_t *postprocess_outputs_count) const {
  125. std::string model_str = servable_meta.Repr();
  126. const auto &common_meta = servable_meta.common_meta;
  127. const auto &postprocess_name = method.postprocess_name;
  128. if (!method.postprocess_name.empty()) {
  129. auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name);
  130. if (postprocess == nullptr) {
  131. return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
  132. << " postprocess " << postprocess_name << " not defined";
  133. }
  134. *postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name);
  135. for (size_t i = 0; i < method.postprocess_inputs.size(); i++) {
  136. auto &input = method.postprocess_inputs[i];
  137. if (input.first == kPredictPhaseTag_Input) {
  138. if (input.second >= method.inputs.size()) {
  139. return INFER_STATUS_LOG_ERROR(FAILED)
  140. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  141. << "th input uses method " << input.second << "th input, that is greater than the method inputs size "
  142. << method.inputs.size();
  143. }
  144. } else if (input.first == kPredictPhaseTag_Preproces) {
  145. if (input.second >= preprocess_outputs_count) {
  146. return INFER_STATUS_LOG_ERROR(FAILED)
  147. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  148. << "th input uses preprocess " << input.second
  149. << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
  150. }
  151. } else if (input.first == kPredictPhaseTag_Predict) {
  152. if (input.second >= common_meta.outputs_count) {
  153. return INFER_STATUS_LOG_ERROR(FAILED)
  154. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  155. << "th input uses servable " << input.second
  156. << "th output, that is greater than the servable outputs size " << common_meta.outputs_count;
  157. }
  158. } else {
  159. return INFER_STATUS_LOG_ERROR(FAILED)
  160. << "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i
  161. << "th input cannot not come from '" << input.first << "'";
  162. }
  163. }
  164. }
  165. return SUCCESS;
  166. }
  167. Status ServableSignature::CheckReturn(const MethodSignature &method, size_t preprocess_outputs_count,
  168. size_t postprocess_outputs_count) const {
  169. std::string model_str = servable_meta.Repr();
  170. const auto &common_meta = servable_meta.common_meta;
  171. for (size_t i = 0; i < method.returns.size(); i++) {
  172. auto &input = method.returns[i];
  173. if (input.first == kPredictPhaseTag_Input) {
  174. if (input.second >= method.inputs.size()) {
  175. return INFER_STATUS_LOG_ERROR(FAILED)
  176. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  177. << "th output uses method " << input.second << "th input, that is greater than the method inputs size "
  178. << method.inputs.size();
  179. }
  180. } else if (input.first == kPredictPhaseTag_Preproces) {
  181. if (input.second >= preprocess_outputs_count) {
  182. return INFER_STATUS_LOG_ERROR(FAILED)
  183. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  184. << "th output uses preprocess " << input.second
  185. << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
  186. }
  187. } else if (input.first == kPredictPhaseTag_Predict) {
  188. if (input.second >= common_meta.outputs_count) {
  189. return INFER_STATUS_LOG_ERROR(FAILED)
  190. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  191. << "th output uses servable " << input.second
  192. << "th output, that is greater than the servable outputs size " << common_meta.outputs_count;
  193. }
  194. } else if (input.first == kPredictPhaseTag_Postprocess) {
  195. if (input.second >= postprocess_outputs_count) {
  196. return INFER_STATUS_LOG_ERROR(FAILED)
  197. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  198. << "th output uses postprocess " << input.second
  199. << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count;
  200. }
  201. } else {
  202. return INFER_STATUS_LOG_ERROR(FAILED)
  203. << "Model " << model_str << " method " << method.method_name << ", the data of method " << i
  204. << "th output cannot not come from '" << input.first << "'";
  205. }
  206. }
  207. return SUCCESS;
  208. }
  209. Status ServableSignature::Check() const {
  210. std::set<std::string> method_set;
  211. Status status;
  212. for (auto &method : methods) {
  213. if (method_set.count(method.method_name) > 0) {
  214. return INFER_STATUS_LOG_ERROR(FAILED)
  215. << "Model " << servable_meta.Repr() << " " << method.method_name << " has been defined repeatedly";
  216. }
  217. method_set.emplace(method.method_name);
  218. size_t preprocess_outputs_count = 0;
  219. size_t postprocess_outputs_count = 0;
  220. status = CheckPreprocessInput(method, &preprocess_outputs_count);
  221. if (status != SUCCESS) {
  222. return status;
  223. }
  224. status = CheckPredictInput(method, preprocess_outputs_count);
  225. if (status != SUCCESS) {
  226. return status;
  227. }
  228. status = CheckPostprocessInput(method, preprocess_outputs_count, &postprocess_outputs_count);
  229. if (status != SUCCESS) {
  230. return status;
  231. }
  232. status = CheckReturn(method, preprocess_outputs_count, postprocess_outputs_count);
  233. if (status != SUCCESS) {
  234. return status;
  235. }
  236. }
  237. return SUCCESS;
  238. }
  239. bool ServableSignature::GetMethodDeclare(const std::string &method_name, MethodSignature *method) {
  240. MSI_EXCEPTION_IF_NULL(method);
  241. auto item =
  242. find_if(methods.begin(), methods.end(), [&](const MethodSignature &v) { return v.method_name == method_name; });
  243. if (item != methods.end()) {
  244. *method = *item;
  245. return true;
  246. }
  247. return false;
  248. }
  249. void ServableStorage::Register(const ServableSignature &def) {
  250. auto model_name = def.servable_meta.common_meta.servable_name;
  251. if (servable_signatures_map_.find(model_name) == servable_signatures_map_.end()) {
  252. MSI_LOG_WARNING << "Servable " << model_name << " has already been defined";
  253. }
  254. servable_signatures_map_[model_name] = def;
  255. }
  256. bool ServableStorage::GetServableDef(const std::string &model_name, ServableSignature *def) const {
  257. MSI_EXCEPTION_IF_NULL(def);
  258. auto it = servable_signatures_map_.find(model_name);
  259. if (it == servable_signatures_map_.end()) {
  260. return false;
  261. }
  262. *def = it->second;
  263. return true;
  264. }
  265. ServableStorage &ServableStorage::Instance() {
  266. static ServableStorage storage;
  267. return storage;
  268. }
  269. Status ServableStorage::RegisterMethod(const MethodSignature &method) {
  270. MSI_LOG_INFO << "Declare method " << method.method_name << ", servable " << method.servable_name;
  271. auto it = servable_signatures_map_.find(method.servable_name);
  272. if (it == servable_signatures_map_.end()) {
  273. ServableSignature signature;
  274. signature.methods.push_back(method);
  275. servable_signatures_map_[method.servable_name] = signature;
  276. return SUCCESS;
  277. }
  278. for (auto &item : it->second.methods) {
  279. // cppcheck-suppress useStlAlgorithm
  280. if (item.method_name == method.method_name) {
  281. return INFER_STATUS_LOG_ERROR(FAILED)
  282. << "Method " << method.method_name << " has been registered more than once.";
  283. }
  284. }
  285. it->second.methods.push_back(method);
  286. return SUCCESS;
  287. }
  288. Status ServableStorage::DeclareServable(ServableMeta servable) {
  289. auto &common_meta = servable.common_meta;
  290. MSI_LOG_INFO << "Declare servable " << common_meta.servable_name;
  291. servable.servable_type = kServableTypeLocal;
  292. if (servable.local_meta.servable_file.empty()) {
  293. return INFER_STATUS_LOG_ERROR(FAILED)
  294. << "Declare servable " << common_meta.servable_name << " failed, servable_file cannot be empty";
  295. }
  296. if (servable.local_meta.model_format == ModelType::kUnknownType) {
  297. return INFER_STATUS_LOG_ERROR(FAILED)
  298. << "Declare servable " << common_meta.servable_name << " failed, model_format is not inited";
  299. }
  300. auto it = servable_signatures_map_.find(common_meta.servable_name);
  301. if (it == servable_signatures_map_.end()) {
  302. ServableSignature signature;
  303. signature.servable_meta = servable;
  304. servable_signatures_map_[common_meta.servable_name] = signature;
  305. return SUCCESS;
  306. }
  307. auto &org_servable_meta = it->second.servable_meta;
  308. if (org_servable_meta.servable_type != kServableTypeUnknown) {
  309. return INFER_STATUS_LOG_ERROR(FAILED)
  310. << "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr();
  311. }
  312. org_servable_meta = servable;
  313. return SUCCESS;
  314. }
  315. Status ServableStorage::DeclareDistributedServable(ServableMeta servable) {
  316. auto &common_meta = servable.common_meta;
  317. MSI_LOG_INFO << "Declare servable " << common_meta.servable_name;
  318. servable.servable_type = kServableTypeDistributed;
  319. if (servable.distributed_meta.rank_size == 0) {
  320. return INFER_STATUS_LOG_ERROR(FAILED)
  321. << "Declare distributed servable " << common_meta.servable_name << " failed, rank_size cannot be 0";
  322. }
  323. if (servable.distributed_meta.stage_size == 0) {
  324. return INFER_STATUS_LOG_ERROR(FAILED)
  325. << "Declare distributed servable " << common_meta.servable_name << " failed, stage_size cannot be 0";
  326. }
  327. auto it = servable_signatures_map_.find(common_meta.servable_name);
  328. if (it == servable_signatures_map_.end()) {
  329. ServableSignature signature;
  330. signature.servable_meta = servable;
  331. servable_signatures_map_[common_meta.servable_name] = signature;
  332. return SUCCESS;
  333. }
  334. auto &org_servable_meta = it->second.servable_meta;
  335. if (org_servable_meta.servable_type != kServableTypeUnknown) {
  336. return INFER_STATUS_LOG_ERROR(FAILED)
  337. << "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr();
  338. }
  339. org_servable_meta = servable;
  340. return SUCCESS;
  341. }
  342. Status ServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count,
  343. size_t outputs_count) {
  344. auto it = servable_signatures_map_.find(servable_name);
  345. if (it == servable_signatures_map_.end()) {
  346. return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, cannot find servable " << servable_name;
  347. }
  348. auto &servable_meta = it->second.servable_meta;
  349. auto &common_meta = servable_meta.common_meta;
  350. if (common_meta.inputs_count != 0 && common_meta.inputs_count != inputs_count) {
  351. return INFER_STATUS_LOG_ERROR(FAILED)
  352. << "RegisterInputOutputInfo failed, inputs count " << inputs_count << " not match old count "
  353. << common_meta.inputs_count << ",servable name " << servable_name;
  354. }
  355. if (common_meta.outputs_count != 0 && common_meta.outputs_count != outputs_count) {
  356. return INFER_STATUS_LOG_ERROR(FAILED)
  357. << "RegisterInputOutputInfo failed, outputs count " << outputs_count << " not match old count "
  358. << common_meta.outputs_count << ",servable name " << servable_name;
  359. }
  360. common_meta.inputs_count = inputs_count;
  361. common_meta.outputs_count = outputs_count;
  362. return SUCCESS;
  363. }
  364. std::vector<size_t> ServableStorage::GetInputOutputInfo(const std::string &servable_name) const {
  365. std::vector<size_t> result;
  366. auto it = servable_signatures_map_.find(servable_name);
  367. if (it == servable_signatures_map_.end()) {
  368. return result;
  369. }
  370. result.push_back(it->second.servable_meta.common_meta.inputs_count);
  371. result.push_back(it->second.servable_meta.common_meta.outputs_count);
  372. return result;
  373. }
  374. void ServableStorage::Clear() { servable_signatures_map_.clear(); }
  375. } // namespace mindspore::serving

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.