You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.py 9.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test Serving, Common"""
  16. import os
  17. from functools import wraps
  18. from shutil import rmtree
  19. from mindspore_serving import master
  20. from mindspore_serving import worker
  21. from mindspore_serving.client import Client
  22. servable_index = 0
  23. class ServingTestBase:
  24. def __init__(self):
  25. servable_dir = "serving_python_ut_servables"
  26. self.servable_dir = os.path.join(os.getcwd(), servable_dir)
  27. rmtree(self.servable_dir, True)
  28. def init_servable(self, version_number, config_file, model_file="tensor_add.mindir"):
  29. cur_dir = os.path.dirname(os.path.abspath(__file__))
  30. config_file_abs = os.path.join(os.path.join(cur_dir, "../servable_config/"), config_file)
  31. try:
  32. with open(config_file_abs, "r") as fp:
  33. servable_config_content = fp.read()
  34. except FileNotFoundError:
  35. servable_config_content = None
  36. self.init_servable_with_servable_config(version_number, servable_config_content, model_file)
  37. def init_servable_with_servable_config(self, version_number, servable_config_content,
  38. model_file="tensor_add.mindir"):
  39. global servable_index
  40. self.servable_name = "add_" + str(servable_index)
  41. servable_index += 1
  42. self.version_number = version_number
  43. self.model_file_name = model_file
  44. self.servable_name_path = os.path.join(self.servable_dir, self.servable_name)
  45. self.version_number_path = os.path.join(self.servable_name_path, str(version_number))
  46. self.model_file_name_path = os.path.join(self.version_number_path, model_file)
  47. try:
  48. os.mkdir(self.servable_dir)
  49. except FileExistsError:
  50. pass
  51. try:
  52. os.mkdir(self.servable_name_path)
  53. except FileExistsError:
  54. pass
  55. try:
  56. os.mkdir(self.version_number_path)
  57. except FileExistsError:
  58. pass
  59. with open(self.model_file_name_path, "w") as fp:
  60. print("model content", file=fp)
  61. if servable_config_content is not None:
  62. config_file = os.path.join(self.servable_name_path, "servable_config.py")
  63. with open(config_file, "w") as fp:
  64. fp.write(servable_config_content)
  65. client_create_list = []
  66. def serving_test(func):
  67. @wraps(func)
  68. def wrap_test(*args, **kwargs):
  69. try:
  70. func(*args, **kwargs)
  71. finally:
  72. master.stop()
  73. worker.stop()
  74. servable_dir = os.path.join(os.getcwd(), "serving_python_ut_servables")
  75. rmtree(servable_dir, True)
  76. global client_create_list
  77. for client in client_create_list:
  78. del client.stub
  79. client.stub = None
  80. client_create_list = []
  81. return wrap_test
  82. def create_client(ip, port, servable_name, method_name, version_number=0):
  83. client = Client(ip, port, servable_name, method_name, version_number)
  84. client_create_list.append(client)
  85. return client
  86. def release_client(client):
  87. del client.stub
  88. client.stub = None
  89. # test servable_config.py with client
  90. servable_config_import = r"""
  91. import numpy as np
  92. from mindspore_serving.worker import register
  93. """
  94. servable_config_declare_servable = r"""
  95. register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
  96. """
  97. servable_config_preprocess_cast = r"""
  98. def add_trans_datatype(x1, x2):
  99. return x1.astype(np.float32), x2.astype(np.float32)
  100. """
  101. servable_config_method_add_common = r"""
  102. @register.register_method(output_names=["y"])
  103. def add_common(x1, x2): # only support float32 inputs
  104. y = register.call_servable(x1, x2)
  105. return y
  106. """
  107. servable_config_method_add_cast = r"""
  108. @register.register_method(output_names=["y"])
  109. def add_cast(x1, x2):
  110. x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) # cast input to float32
  111. y = register.call_servable(x1, x2)
  112. return y
  113. """
  114. def init_add_servable():
  115. base = ServingTestBase()
  116. servable_content = servable_config_import
  117. servable_content += servable_config_declare_servable
  118. servable_content += servable_config_preprocess_cast
  119. servable_content += servable_config_method_add_common
  120. servable_content += servable_config_method_add_cast
  121. base.init_servable_with_servable_config(1, servable_content)
  122. return base
  123. def init_str_servable():
  124. base = ServingTestBase()
  125. servable_content = servable_config_import
  126. servable_content += servable_config_declare_servable
  127. servable_content += r"""
  128. def preprocess(other):
  129. return np.ones([2,2], np.float32), np.ones([2,2], np.float32)
  130. def str_concat_postprocess(text1, text2):
  131. return text1 + text2
  132. @register.register_method(output_names=["text"])
  133. def str_concat(text1, text2):
  134. x1, x2 = register.call_preprocess(preprocess, text1)
  135. y = register.call_servable(x1, x2)
  136. text = register.call_postprocess(str_concat_postprocess, text1, text2)
  137. return text
  138. def str_empty_postprocess(text1, text2):
  139. if len(text1) == 0:
  140. text = text2
  141. else:
  142. text = ""
  143. return text
  144. @register.register_method(output_names=["text"])
  145. def str_empty(text1, text2):
  146. x1, x2 = register.call_preprocess(preprocess, text1)
  147. y = register.call_servable(x1, x2)
  148. text = register.call_postprocess(str_empty_postprocess, text1, text2)
  149. return text
  150. """
  151. base.init_servable_with_servable_config(1, servable_content)
  152. return base
  153. def init_bytes_servable():
  154. base = ServingTestBase()
  155. servable_content = servable_config_import
  156. servable_content += servable_config_declare_servable
  157. servable_content += r"""
  158. def preprocess(other):
  159. return np.ones([2,2], np.float32), np.ones([2,2], np.float32)
  160. def bytes_concat_postprocess(text1, text2):
  161. text1 = bytes.decode(text1.tobytes()) # bytes decode to str
  162. text2 = bytes.decode(text2.tobytes()) # bytes decode to str
  163. return str.encode(text1 + text2) # str encode to bytes
  164. @register.register_method(output_names=["text"])
  165. def bytes_concat(text1, text2):
  166. x1, x2 = register.call_preprocess(preprocess, text1)
  167. y = register.call_servable(x1, x2)
  168. text = register.call_postprocess(bytes_concat_postprocess, text1, text2)
  169. return text
  170. def bytes_empty_postprocess(text1, text2):
  171. text1 = bytes.decode(text1.tobytes()) # bytes decode to str
  172. text2 = bytes.decode(text2.tobytes()) # bytes decode to str
  173. if len(text1) == 0:
  174. text = text2
  175. else:
  176. text = ""
  177. return str.encode(text) # str encode to bytes
  178. @register.register_method(output_names=["text"])
  179. def bytes_empty(text1, text2):
  180. x1, x2 = register.call_preprocess(preprocess, text1)
  181. y = register.call_servable(x1, x2)
  182. text = register.call_postprocess(bytes_empty_postprocess, text1, text2)
  183. return text
  184. """
  185. base.init_servable_with_servable_config(1, servable_content)
  186. return base
  187. def init_bool_int_float_servable():
  188. base = ServingTestBase()
  189. servable_content = servable_config_import
  190. servable_content += servable_config_declare_servable
  191. servable_content += r"""
  192. def preprocess(other):
  193. return np.ones([2,2], np.float32), np.ones([2,2], np.float32)
  194. def bool_postprocess(bool_val):
  195. return ~bool_val
  196. @register.register_method(output_names=["value"])
  197. def bool_not(bool_val):
  198. x1, x2 = register.call_preprocess(preprocess, bool_val)
  199. y = register.call_servable(x1, x2)
  200. value = register.call_postprocess(bool_postprocess, bool_val)
  201. return value
  202. def int_postprocess(int_val):
  203. return int_val + 1
  204. @register.register_method(output_names=["value"])
  205. def int_plus_1(int_val):
  206. x1, x2 = register.call_preprocess(preprocess, int_val)
  207. y = register.call_servable(x1, x2)
  208. value = register.call_postprocess(int_postprocess, int_val)
  209. return value
  210. def float_postprocess(float_val):
  211. value = float_val + 1
  212. if value.dtype == np.float16:
  213. value = value.astype(np.float32)
  214. return value
  215. @register.register_method(output_names=["value"])
  216. def float_plus_1(float_val):
  217. x1, x2 = register.call_preprocess(preprocess, float_val)
  218. y = register.call_servable(x1, x2)
  219. value = register.call_postprocess(float_postprocess, float_val)
  220. return value
  221. """
  222. base.init_servable_with_servable_config(1, servable_content)
  223. return base

A lightweight, high-performance serving module that helps MindSpore developers efficiently deploy online inference services in production environments.