You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

servable.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "common/servable.h"
  #include <algorithm>
  #include <set>
  #include <sstream>
  #include "worker/preprocess.h"
  #include "worker/postprocess.h"
  21. namespace mindspore::serving {
  22. std::string ServableMeta::Repr() const {
  23. std::ostringstream stream;
  24. stream << "path(" << servable_name << ") file(" << servable_file + ")";
  25. return stream.str();
  26. }
  27. void ServableMeta::SetModelFormat(const std::string &format) {
  28. if (format == "om") {
  29. model_format = kOM;
  30. } else if (format == "mindir") {
  31. model_format = kMindIR;
  32. } else {
  33. MSI_LOG_ERROR << "Invalid model format " << format;
  34. }
  35. }
  36. std::string LoadServableSpec::Repr() const {
  37. std::string version;
  38. if (version_number > 0) {
  39. version = " version(" + std::to_string(version_number) + ") ";
  40. }
  41. return "servable(" + servable_name + ") " + version;
  42. }
  43. std::string WorkerSpec::Repr() const {
  44. std::string version;
  45. if (version_number > 0) {
  46. version = " version(" + std::to_string(version_number) + ") ";
  47. }
  48. return "servable(" + servable_name + ") " + version + " address(" + worker_address + ") ";
  49. }
  50. std::string RequestSpec::Repr() const {
  51. std::string version;
  52. if (version_number > 0) {
  53. version = " version(" + std::to_string(version_number) + ") ";
  54. }
  55. return "servable(" + servable_name + ") " + "method(" + method_name + ") " + version;
  56. }
  57. Status ServableSignature::Check() const {
  58. std::set<std::string> method_set;
  59. std::string model_str = servable_meta.Repr();
  60. for (auto &method : methods) {
  61. if (method_set.count(method.method_name) > 0) {
  62. return INFER_STATUS_LOG_ERROR(FAILED)
  63. << "Model " << model_str << " " << method.method_name << " has been defined repeatly";
  64. }
  65. method_set.emplace(method.method_name);
  66. size_t preprocess_outputs_count = 0;
  67. size_t postprocess_outputs_count = 0;
  68. const auto &preprocess_name = method.preprocess_name;
  69. if (!preprocess_name.empty()) {
  70. auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name);
  71. if (preprocess == nullptr) {
  72. return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
  73. << " preprocess " << preprocess_name << " not defined";
  74. }
  75. preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name);
  76. for (size_t i = 0; i < method.preprocess_inputs.size(); i++) {
  77. auto &input = method.preprocess_inputs[i];
  78. if (input.first != kPredictPhaseTag_Input) {
  79. return INFER_STATUS_LOG_ERROR(FAILED)
  80. << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i
  81. << "th input phase tag " << input.first << " invalid";
  82. }
  83. if (input.second >= method.inputs.size()) {
  84. return INFER_STATUS_LOG_ERROR(FAILED)
  85. << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i
  86. << "th input uses method " << input.second << "th input, that is greater than the method inputs size "
  87. << method.inputs.size();
  88. }
  89. }
  90. }
  91. for (size_t i = 0; i < method.servable_inputs.size(); i++) {
  92. auto &input = method.servable_inputs[i];
  93. if (input.first == kPredictPhaseTag_Input) {
  94. if (input.second >= method.inputs.size()) {
  95. return INFER_STATUS_LOG_ERROR(FAILED)
  96. << "Model " << model_str << " method " << method.method_name << ", the servable " << i
  97. << "th input uses method " << input.second << "th input, that is greater than the method inputs size "
  98. << method.inputs.size();
  99. }
  100. } else if (input.first == kPredictPhaseTag_Preproces) {
  101. if (input.second >= preprocess_outputs_count) {
  102. return INFER_STATUS_LOG_ERROR(FAILED)
  103. << "Model " << model_str << " method " << method.method_name << ", the servable " << i
  104. << "th input uses preprocess " << input.second
  105. << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
  106. }
  107. } else {
  108. return INFER_STATUS_LOG_ERROR(FAILED)
  109. << "Model " << model_str << " method " << method.method_name << ", the servable " << i
  110. << "th input phase tag " << input.first << " invalid";
  111. }
  112. }
  113. const auto &postprocess_name = method.postprocess_name;
  114. if (!method.postprocess_name.empty()) {
  115. auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name);
  116. if (postprocess == nullptr) {
  117. return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
  118. << " postprocess " << postprocess_name << " not defined";
  119. }
  120. postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name);
  121. for (size_t i = 0; i < method.postprocess_inputs.size(); i++) {
  122. auto &input = method.postprocess_inputs[i];
  123. if (input.first == kPredictPhaseTag_Input) {
  124. if (input.second >= method.inputs.size()) {
  125. return INFER_STATUS_LOG_ERROR(FAILED)
  126. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  127. << "th input uses method " << input.second
  128. << "th input, that is greater than the method inputs size " << method.inputs.size();
  129. }
  130. } else if (input.first == kPredictPhaseTag_Preproces) {
  131. if (input.second >= preprocess_outputs_count) {
  132. return INFER_STATUS_LOG_ERROR(FAILED)
  133. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  134. << "th input uses preprocess " << input.second
  135. << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
  136. }
  137. } else if (input.first == kPredictPhaseTag_Predict) {
  138. if (input.second >= servable_meta.outputs_count) {
  139. return INFER_STATUS_LOG_ERROR(FAILED)
  140. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  141. << "th input uses servable " << input.second
  142. << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count;
  143. }
  144. } else {
  145. return INFER_STATUS_LOG_ERROR(FAILED)
  146. << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
  147. << "th input phase tag " << input.first << " invalid";
  148. }
  149. }
  150. }
  151. for (size_t i = 0; i < method.returns.size(); i++) {
  152. auto &input = method.returns[i];
  153. if (input.first == kPredictPhaseTag_Input) {
  154. if (input.second >= method.inputs.size()) {
  155. return INFER_STATUS_LOG_ERROR(FAILED)
  156. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  157. << "th output uses method " << input.second << "th input, that is greater than the method inputs size "
  158. << method.inputs.size();
  159. }
  160. } else if (input.first == kPredictPhaseTag_Preproces) {
  161. if (input.second >= preprocess_outputs_count) {
  162. return INFER_STATUS_LOG_ERROR(FAILED)
  163. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  164. << "th output uses preprocess " << input.second
  165. << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
  166. }
  167. } else if (input.first == kPredictPhaseTag_Predict) {
  168. if (input.second >= servable_meta.outputs_count) {
  169. return INFER_STATUS_LOG_ERROR(FAILED)
  170. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  171. << "th output uses servable " << input.second
  172. << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count;
  173. }
  174. } else if (input.first == kPredictPhaseTag_Postprocess) {
  175. if (input.second >= postprocess_outputs_count) {
  176. return INFER_STATUS_LOG_ERROR(FAILED)
  177. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  178. << "th output uses postprocess " << input.second
  179. << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count;
  180. }
  181. } else {
  182. return INFER_STATUS_LOG_ERROR(FAILED)
  183. << "Model " << model_str << " method " << method.method_name << ", the method " << i
  184. << "th output phase tag " << input.first << " invalid";
  185. }
  186. }
  187. }
  188. return SUCCESS;
  189. }
  190. bool ServableSignature::GetMethodDeclare(const std::string &method_name, MethodSignature *method) {
  191. MSI_EXCEPTION_IF_NULL(method);
  192. auto item =
  193. find_if(methods.begin(), methods.end(), [&](const MethodSignature &v) { return v.method_name == method_name; });
  194. if (item != methods.end()) {
  195. *method = *item;
  196. return true;
  197. }
  198. return false;
  199. }
  200. void ServableStorage::Register(const ServableSignature &def) {
  201. auto model_name = def.servable_meta.servable_name;
  202. if (servable_signatures_map_.find(model_name) == servable_signatures_map_.end()) {
  203. MSI_LOG_WARNING << "Servable " << model_name << " has already been defined";
  204. }
  205. servable_signatures_map_[model_name] = def;
  206. }
  207. bool ServableStorage::GetServableDef(const std::string &model_name, ServableSignature *def) const {
  208. MSI_EXCEPTION_IF_NULL(def);
  209. auto it = servable_signatures_map_.find(model_name);
  210. if (it == servable_signatures_map_.end()) {
  211. return false;
  212. }
  213. *def = it->second;
  214. return true;
  215. }
  216. std::shared_ptr<ServableStorage> ServableStorage::Instance() {
  217. static std::shared_ptr<ServableStorage> storage;
  218. if (storage == nullptr) {
  219. storage = std::make_shared<ServableStorage>();
  220. }
  221. return storage;
  222. }
  223. void ServableStorage::RegisterMethod(const MethodSignature &method) {
  224. MSI_LOG_INFO << "Declare method " << method.method_name << ", servable " << method.servable_name;
  225. auto it = servable_signatures_map_.find(method.servable_name);
  226. if (it == servable_signatures_map_.end()) {
  227. ServableSignature signature;
  228. signature.methods.push_back(method);
  229. servable_signatures_map_[method.servable_name] = signature;
  230. return;
  231. }
  232. it->second.methods.push_back(method);
  233. }
  234. void ServableStorage::DeclareServable(const mindspore::serving::ServableMeta &servable) {
  235. MSI_LOG_INFO << "Declare servable " << servable.servable_name;
  236. auto it = servable_signatures_map_.find(servable.servable_name);
  237. if (it == servable_signatures_map_.end()) {
  238. ServableSignature signature;
  239. signature.servable_meta = servable;
  240. servable_signatures_map_[servable.servable_name] = signature;
  241. return;
  242. }
  243. it->second.servable_meta = servable;
  244. }
  245. void ServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count,
  246. size_t outputs_count) {
  247. auto it = servable_signatures_map_.find(servable_name);
  248. if (it == servable_signatures_map_.end()) {
  249. MSI_LOG_EXCEPTION << "RegisterInputOutputInfo failed, cannot find servable " << servable_name;
  250. }
  251. auto &servable_meta = it->second.servable_meta;
  252. if (servable_meta.inputs_count != 0 && servable_meta.inputs_count != inputs_count) {
  253. MSI_LOG_EXCEPTION << "RegisterInputOutputInfo failed, inputs count " << inputs_count << " not match old count "
  254. << servable_meta.inputs_count << ",servable name " << servable_name;
  255. }
  256. if (servable_meta.outputs_count != 0 && servable_meta.outputs_count != outputs_count) {
  257. MSI_LOG_EXCEPTION << "RegisterInputOutputInfo failed, outputs count " << outputs_count << " not match old count "
  258. << servable_meta.outputs_count << ",servable name " << servable_name;
  259. }
  260. servable_meta.inputs_count = inputs_count;
  261. servable_meta.outputs_count = outputs_count;
  262. }
  263. std::vector<size_t> ServableStorage::GetInputOutputInfo(const std::string &servable_name) const {
  264. std::vector<size_t> result;
  265. auto it = servable_signatures_map_.find(servable_name);
  266. if (it == servable_signatures_map_.end()) {
  267. return result;
  268. }
  269. result.push_back(it->second.servable_meta.inputs_count);
  270. result.push_back(it->second.servable_meta.outputs_count);
  271. return result;
  272. }
  273. } // namespace mindspore::serving

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.