You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

method.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Method registration interface"""
  16. import inspect
  17. import ast
  18. from functools import wraps
  19. from easydict import EasyDict
  20. from mindspore_serving._mindspore_serving import ServableStorage_, MethodSignature_, PredictPhaseTag_
  21. from mindspore_serving.worker.common import get_func_name, get_servable_dir
  22. from mindspore_serving.worker import check_type
  23. from mindspore_serving import log as logger
  24. from .preprocess import register_preprocess, check_preprocess
  25. from .postprocess import register_postprocess, check_postprocess
  26. method_def_context_ = MethodSignature_()
  27. method_def_ast_meta_ = EasyDict()
  28. method_tag_input = PredictPhaseTag_.kPredictPhaseTag_Input
  29. method_tag_preprocess = PredictPhaseTag_.kPredictPhaseTag_Preproces
  30. method_tag_predict = PredictPhaseTag_.kPredictPhaseTag_Predict
  31. method_tag_postprocess = PredictPhaseTag_.kPredictPhaseTag_Postprocess
  32. class _ServableStorage:
  33. """Declare servable info"""
  34. def __init__(self):
  35. self.methods = {}
  36. self.servable_metas = {}
  37. self.storage = ServableStorage_.get_instance()
  38. def declare_servable(self, servable_meta):
  39. """Declare servable info excluding method, input and output count"""
  40. self.storage.declare_servable(servable_meta)
  41. self.servable_metas[servable_meta.servable_name] = servable_meta
  42. def declare_servable_input_output(self, servable_name, inputs_count, outputs_count):
  43. """Declare input and output count of servable"""
  44. self.storage.register_servable_input_output_info(servable_name, inputs_count, outputs_count)
  45. servable_meta = self.servable_metas[servable_name]
  46. servable_meta.inputs_count = inputs_count
  47. servable_meta.outputs_count = outputs_count
  48. def register_method(self, method_signature):
  49. """Declare method of servable"""
  50. if method_signature.method_name in self.methods:
  51. raise RuntimeError(f"Method {method_signature.method_name} has been registered more than once.")
  52. self.storage.register_method(method_signature)
  53. self.methods[method_signature.method_name] = method_signature
  54. def get_method(self, method_name):
  55. method = self.methods.get(method_name, None)
  56. if method is None:
  57. raise RuntimeError(f"Method '{method_name}' not found")
  58. return method
  59. def get_servable_meta(self, servable_name):
  60. servable = self.servable_metas.get(servable_name, None)
  61. if servable is None:
  62. raise RuntimeError(f"Servable '{servable_name}' not found")
  63. return servable
# Module-level registry instance; call_servable and register_method in this module record into it.
_servable_storage = _ServableStorage()
  65. class _TensorDef:
  66. """Data flow item, for definitions of data flow in a method"""
  67. def __init__(self, tag, tensor_index):
  68. self.tag = tag
  69. self.tensor_index = tensor_index
  70. def as_pair(self):
  71. return (self.tag, self.tensor_index)
  72. def _create_tensor_def_outputs(tag, outputs_cnt):
  73. """Create data flow item for output"""
  74. result = [_TensorDef(tag, i) for i in range(outputs_cnt)]
  75. if len(result) == 1:
  76. return result[0]
  77. return tuple(result)
  78. def _wrap_fun_to_pipeline(fun, input_count):
  79. """wrap preprocess and postprocess to pipeline"""
  80. argspec_len = len(inspect.signature(fun).parameters)
  81. if argspec_len != input_count:
  82. raise RuntimeError(f"function {fun.__name__} input args count {argspec_len} not match "
  83. f"registered in method count {input_count}")
  84. @wraps(fun)
  85. def call_func(instances):
  86. for instance in instances:
  87. inputs = []
  88. for i in range(input_count):
  89. inputs.append(instance[i])
  90. yield fun(*inputs)
  91. return call_func
  92. def call_preprocess_pipeline(preprocess_fun, *args):
  93. r"""For method registration, define the preprocessing pipeline function and its' parameters.
  94. Args:
  95. preprocess_fun (function): Python pipeline function for preprocess.
  96. args: Preprocess inputs. The length of 'args' should equal to the input parameters number
  97. of implemented python function.
  98. Raises:
  99. RuntimeError: The type or value of the parameters is invalid, or other error happened.
  100. Examples:
  101. >>> from mindspore_serving.worker import register
  102. >>> import numpy as np
  103. >>> def add_trans_datatype(x1, x2):
  104. ... return x1.astype(np.float32), x2.astype(np.float32)
  105. >>>
  106. >>> register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
  107. >>>
  108. >>> @register.register_method(output_names=["y"]) # register add_cast method in add
  109. >>> def add_cast(x1, x2):
  110. ... x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) # cast input to float32
  111. ... y = register.call_servable(x1, x2)
  112. ... return y
  113. """
  114. global method_def_context_
  115. if method_def_context_.preprocess_name:
  116. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', "
  117. f"call_preprocess or call_preprocess_pipeline should not be invoked more than once")
  118. if method_def_context_.servable_name:
  119. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', "
  120. f"call_servable should be invoked after call_preprocess_pipeline")
  121. if method_def_context_.postprocess_name:
  122. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', call_postprocess "
  123. f"or call_postprocess_pipeline should be invoked after call_preprocess_pipeline")
  124. if _call_preprocess_pipeline_name not in method_def_ast_meta_:
  125. raise RuntimeError(f"Invalid call of '{_call_preprocess_pipeline_name}'")
  126. inputs_count, outputs_count = method_def_ast_meta_[_call_preprocess_pipeline_name]
  127. preprocess_name = preprocess_fun
  128. if inspect.isfunction(preprocess_fun):
  129. register_preprocess(preprocess_fun, inputs_count=inputs_count, outputs_count=outputs_count)
  130. preprocess_name = get_servable_dir() + "." + get_func_name(preprocess_fun)
  131. else:
  132. if not isinstance(preprocess_name, str):
  133. raise RuntimeError(
  134. f"Check failed in method '{method_def_context_.method_name}', "
  135. f"call_preprocess first must be function or str, now is {type(preprocess_name)}")
  136. check_preprocess(preprocess_name, inputs_count=inputs_count, outputs_count=outputs_count)
  137. method_def_context_.preprocess_name = preprocess_name
  138. method_def_context_.preprocess_inputs = [item.as_pair() for item in args]
  139. return _create_tensor_def_outputs(method_tag_preprocess, outputs_count)
  140. def call_preprocess(preprocess_fun, *args):
  141. r"""For method registration, define the preprocessing function and its' parameters.
  142. Args:
  143. preprocess_fun (function): Python function for preprocess.
  144. args: Preprocess inputs. The length of 'args' should equal to the input parameters number
  145. of implemented python function.
  146. Raises:
  147. RuntimeError: The type or value of the parameters is invalid, or other error happened.
  148. Examples:
  149. >>> from mindspore_serving.worker import register
  150. >>> import numpy as np
  151. >>> def add_trans_datatype(x1, x2):
  152. ... return x1.astype(np.float32), x2.astype(np.float32)
  153. >>>
  154. >>> register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
  155. >>>
  156. >>> @register.register_method(output_names=["y"]) # register add_cast method in add
  157. >>> def add_cast(x1, x2):
  158. ... x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) # cast input to float32
  159. ... y = register.call_servable(x1, x2)
  160. ... return y
  161. """
  162. global method_def_context_
  163. if method_def_context_.preprocess_name:
  164. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', "
  165. f"call_preprocess or call_preprocess_pipeline should not be invoked more than once")
  166. if method_def_context_.servable_name:
  167. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', "
  168. f"call_servable should be invoked after call_preprocess")
  169. if method_def_context_.postprocess_name:
  170. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', "
  171. f"call_postprocess or call_postprocess_pipeline should be invoked after call_preprocess")
  172. if _call_preprocess_name not in method_def_ast_meta_:
  173. raise RuntimeError(f"Invalid call of '{_call_preprocess_name}'")
  174. inputs_count, outputs_count = method_def_ast_meta_[_call_preprocess_name]
  175. preprocess_name = preprocess_fun
  176. if inspect.isfunction(preprocess_fun):
  177. register_preprocess(_wrap_fun_to_pipeline(preprocess_fun, inputs_count),
  178. inputs_count=inputs_count, outputs_count=outputs_count)
  179. preprocess_name = get_servable_dir() + "." + get_func_name(preprocess_fun)
  180. else:
  181. if not isinstance(preprocess_name, str):
  182. raise RuntimeError(
  183. f"Check failed in method '{method_def_context_.method_name}', "
  184. f"call_preprocess first must be function or str, now is {type(preprocess_name)}")
  185. check_preprocess(preprocess_name, inputs_count=inputs_count, outputs_count=outputs_count)
  186. method_def_context_.preprocess_name = preprocess_name
  187. method_def_context_.preprocess_inputs = [item.as_pair() for item in args]
  188. return _create_tensor_def_outputs(method_tag_preprocess, outputs_count)
  189. def call_servable(*args):
  190. r"""For method registration, define the inputs data of model inference
  191. Note:
  192. The length of 'args' should be equal to the inputs number of model
  193. Args:
  194. args: Model's inputs, the length of 'args' should be equal to the inputs number of model.
  195. Raises:
  196. RuntimeError: The type or value of the parameters is invalid, or other error happened.
  197. Examples:
  198. >>> from mindspore_serving.worker import register
  199. >>> register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
  200. >>>
  201. >>> @register.register_method(output_names=["y"]) # register add_common method in add
  202. >>> def add_common(x1, x2):
  203. ... y = register.call_servable(x1, x2)
  204. ... return y
  205. """
  206. global method_def_context_
  207. if method_def_context_.servable_name:
  208. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', "
  209. f"call_servable should not be invoked more than once")
  210. if method_def_context_.postprocess_name:
  211. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', "
  212. f"call_postprocess or call_postprocess_pipeline should be invoked after call_servable")
  213. servable_name = get_servable_dir()
  214. inputs_count, outputs_count = method_def_ast_meta_[_call_servable_name]
  215. _servable_storage.declare_servable_input_output(servable_name, inputs_count, outputs_count)
  216. if inputs_count != len(args):
  217. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', given servable input "
  218. f"size {len(args)} not match '{servable_name}' ast parse size {inputs_count}")
  219. method_def_context_.servable_name = servable_name
  220. method_def_context_.servable_inputs = [item.as_pair() for item in args]
  221. return _create_tensor_def_outputs(method_tag_predict, outputs_count)
  222. def call_postprocess_pipeline(postprocess_fun, *args):
  223. r"""For method registration, define the postprocessing pipeline function and its' parameters.
  224. Args:
  225. postprocess_fun (function): Python pipeline function for postprocess.
  226. args: Preprocess inputs. The length of 'args' should equal to the input parameters number
  227. of implemented python function.
  228. Raises:
  229. RuntimeError: The type or value of the parameters is invalid, or other error happened.
  230. """
  231. global method_def_context_
  232. if method_def_context_.postprocess_name:
  233. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', "
  234. f"call_postprocess or call_postprocess_pipeline should not be invoked more than once")
  235. if _call_postprocess_pipeline_name not in method_def_ast_meta_:
  236. raise RuntimeError(f"Invalid call of '{_call_postprocess_pipeline_name}'")
  237. inputs_count, outputs_count = method_def_ast_meta_[_call_postprocess_pipeline_name]
  238. postprocess_name = postprocess_fun
  239. if inspect.isfunction(postprocess_fun):
  240. register_postprocess(postprocess_fun, inputs_count=inputs_count, outputs_count=outputs_count)
  241. postprocess_name = get_servable_dir() + "." + get_func_name(postprocess_fun)
  242. else:
  243. if not isinstance(postprocess_name, str):
  244. raise RuntimeError(
  245. f"Check failed in method '{method_def_context_.method_name}', "
  246. f"call_postprocess first must be function or str, now is {type(postprocess_name)}")
  247. check_postprocess(postprocess_name, inputs_count=inputs_count, outputs_count=outputs_count)
  248. method_def_context_.postprocess_name = postprocess_name
  249. method_def_context_.postprocess_inputs = [item.as_pair() for item in args]
  250. return _create_tensor_def_outputs(method_tag_postprocess, outputs_count)
  251. def call_postprocess(postprocess_fun, *args):
  252. r"""For method registration, define the postprocessing function and its' parameters.
  253. Args:
  254. postprocess_fun (function): Python function for postprocess.
  255. args: Preprocess inputs. The length of 'args' should equal to the input parameters number
  256. of implemented python function.
  257. Raises:
  258. RuntimeError: The type or value of the parameters is invalid, or other error happened.
  259. """
  260. global method_def_context_
  261. if method_def_context_.postprocess_name:
  262. raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', "
  263. f"call_postprocess or call_postprocess_pipeline should not be invoked more than once")
  264. if _call_postprocess_name not in method_def_ast_meta_:
  265. raise RuntimeError(f"Invalid call of '{_call_postprocess_name}'")
  266. inputs_count, outputs_count = method_def_ast_meta_[_call_postprocess_name]
  267. postprocess_name = postprocess_fun
  268. if inspect.isfunction(postprocess_fun):
  269. register_postprocess(_wrap_fun_to_pipeline(postprocess_fun, inputs_count),
  270. inputs_count=inputs_count, outputs_count=outputs_count)
  271. postprocess_name = get_servable_dir() + "." + get_func_name(postprocess_fun)
  272. else:
  273. if not isinstance(postprocess_name, str):
  274. raise RuntimeError(
  275. f"Check failed in method '{method_def_context_.method_name}', "
  276. f"call_postprocess first must be function or str, now is {type(postprocess_name)}")
  277. check_postprocess(postprocess_name, inputs_count=inputs_count, outputs_count=outputs_count)
  278. method_def_context_.postprocess_name = postprocess_name
  279. method_def_context_.postprocess_inputs = [item.as_pair() for item in args]
  280. return _create_tensor_def_outputs(method_tag_postprocess, outputs_count)
  281. _call_preprocess_name = call_preprocess.__name__
  282. _call_servable_name = call_servable.__name__
  283. _call_postprocess_name = call_postprocess.__name__
  284. _call_preprocess_pipeline_name = call_preprocess_pipeline.__name__
  285. _call_postprocess_pipeline_name = call_postprocess_pipeline.__name__
  286. def _get_method_def_func_meta(method_def_func):
  287. """Parse register_method func, and get the input and output count of preprocess, servable and postprocess"""
  288. source = inspect.getsource(method_def_func)
  289. call_list = ast.parse(source).body[0].body
  290. func_meta = EasyDict()
  291. for call_item in call_list:
  292. if not isinstance(call_item, ast.Assign):
  293. continue
  294. target = call_item.targets[0]
  295. if isinstance(target, ast.Name):
  296. outputs_count = 1
  297. elif isinstance(target, ast.Tuple):
  298. outputs_count = len(target.elts)
  299. else:
  300. continue
  301. call = call_item.value
  302. if not isinstance(call, ast.Call):
  303. continue
  304. func = call.func
  305. if isinstance(func, ast.Attribute):
  306. func_name = func.attr
  307. elif isinstance(func, ast.Name):
  308. func_name = func.id
  309. else:
  310. continue
  311. inputs_count = len(call.args)
  312. if func_name in (_call_preprocess_name, _call_preprocess_pipeline_name,
  313. _call_postprocess_name, _call_postprocess_pipeline_name):
  314. inputs_count -= 1
  315. elif func_name == _call_servable_name:
  316. pass
  317. else:
  318. continue
  319. if inputs_count <= 0:
  320. raise RuntimeError(f"Invalid '{func_name}' invoke args")
  321. logger.info(f"call type '{func_name}', inputs count {inputs_count}, outputs count {outputs_count}")
  322. func_meta[func_name] = [inputs_count, outputs_count]
  323. if _call_servable_name not in func_meta:
  324. raise RuntimeError(f"Not find the invoke of '{_call_servable_name}'")
  325. return func_meta
  326. def register_method(output_names):
  327. """register method for servable.
  328. Define the data flow of preprocess, model inference and postprocess in the method.
  329. Preprocess and postprocess are optional.
  330. Args:
  331. output_names (str, tuple or list of str): The output names of method. The input names is
  332. the args names of the registered function.
  333. Raises:
  334. RuntimeError: The type or value of the parameters is invalid, or other error happened.
  335. Examples:
  336. >>> from mindspore_serving.worker import register
  337. >>> import numpy as np
  338. >>> def add_trans_datatype(x1, x2):
  339. ... return x1.astype(np.float32), x2.astype(np.float32)
  340. >>>
  341. >>> register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
  342. >>>
  343. >>> @register.register_method(output_names=["y"]) # register add_cast method in add
  344. >>> def add_cast(x1, x2):
  345. ... x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) # cast input to float32
  346. ... y = register.call_servable(x1, x2)
  347. ... return y
  348. """
  349. output_names = check_type.check_and_as_str_tuple_list('output_names', output_names)
  350. def register(func):
  351. name = get_func_name(func)
  352. sig = inspect.signature(func)
  353. input_names = []
  354. for k, v in sig.parameters.items():
  355. if v.kind == inspect.Parameter.VAR_POSITIONAL:
  356. raise RuntimeError(f"'{name}' input {k} cannot be VAR_POSITIONAL !")
  357. if v.kind == inspect.Parameter.VAR_KEYWORD:
  358. raise RuntimeError(f"'{name}' input {k} cannot be VAR_KEYWORD !")
  359. input_names.append(k)
  360. input_tensors = []
  361. for i in range(len(input_names)):
  362. input_tensors.append(_TensorDef(method_tag_input, i))
  363. global method_def_context_
  364. method_def_context_ = MethodSignature_()
  365. method_def_context_.method_name = name
  366. method_def_context_.inputs = input_names
  367. method_def_context_.outputs = output_names
  368. global method_def_ast_meta_
  369. method_def_ast_meta_ = _get_method_def_func_meta(func)
  370. output_tensors = func(*tuple(input_tensors))
  371. if isinstance(output_tensors, _TensorDef):
  372. output_tensors = (output_tensors,)
  373. if len(output_tensors) != len(output_names):
  374. raise RuntimeError(
  375. f"Method return output size {len(output_tensors)} not match registered {len(output_names)}")
  376. method_def_context_.returns = [item.as_pair() for item in output_tensors]
  377. logger.info(f"Register method: method_name {method_def_context_.method_name} "
  378. f", servable_name {method_def_context_.servable_name}, inputs: {input_names}, outputs: "
  379. f"{output_names}")
  380. global _servable_storage
  381. _servable_storage.register_method(method_def_context_)
  382. return func
  383. return register

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.