You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

client.py 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """MindSpore Serving Client"""
  16. import grpc
  17. import numpy as np
  18. import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2
  19. import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc
  20. def _create_tensor(data, tensor=None):
  21. """Create tensor from numpy data"""
  22. if tensor is None:
  23. tensor = ms_service_pb2.Tensor()
  24. tensor.shape.dims.extend(data.shape)
  25. dtype_map = {
  26. np.bool: ms_service_pb2.MS_BOOL,
  27. np.int8: ms_service_pb2.MS_INT8,
  28. np.uint8: ms_service_pb2.MS_UINT8,
  29. np.int16: ms_service_pb2.MS_INT16,
  30. np.uint16: ms_service_pb2.MS_UINT16,
  31. np.int32: ms_service_pb2.MS_INT32,
  32. np.uint32: ms_service_pb2.MS_UINT32,
  33. np.int64: ms_service_pb2.MS_INT64,
  34. np.uint64: ms_service_pb2.MS_UINT64,
  35. np.float16: ms_service_pb2.MS_FLOAT16,
  36. np.float32: ms_service_pb2.MS_FLOAT32,
  37. np.float64: ms_service_pb2.MS_FLOAT64,
  38. }
  39. for k, v in dtype_map.items():
  40. if k == data.dtype:
  41. tensor.dtype = v
  42. break
  43. if tensor.dtype == ms_service_pb2.MS_UNKNOWN:
  44. raise RuntimeError("Unknown data type " + str(data.dtype))
  45. tensor.data = data.tobytes()
  46. return tensor
  47. def _create_scalar_tensor(vals, tensor=None):
  48. """Create tensor from scalar data"""
  49. if not isinstance(vals, (tuple, list)):
  50. vals = (vals,)
  51. return _create_tensor(np.array(vals), tensor)
  52. def _create_bytes_tensor(bytes_vals, tensor=None):
  53. """Create tensor from bytes data"""
  54. if tensor is None:
  55. tensor = ms_service_pb2.Tensor()
  56. if not isinstance(bytes_vals, (tuple, list)):
  57. bytes_vals = (bytes_vals,)
  58. tensor.shape.dims.extend([len(bytes_vals)])
  59. tensor.dtype = ms_service_pb2.MS_BYTES
  60. for item in bytes_vals:
  61. tensor.bytes_val.append(item)
  62. return tensor
  63. def _create_str_tensor(str_vals, tensor=None):
  64. """Create tensor from str data"""
  65. if tensor is None:
  66. tensor = ms_service_pb2.Tensor()
  67. if not isinstance(str_vals, (tuple, list)):
  68. str_vals = (str_vals,)
  69. tensor.shape.dims.extend([len(str_vals)])
  70. tensor.dtype = ms_service_pb2.MS_STRING
  71. for item in str_vals:
  72. tensor.bytes_val.append(bytes(item, encoding="utf8"))
  73. return tensor
  74. def _create_numpy_from_tensor(tensor):
  75. """Create numpy from protobuf tensor"""
  76. dtype_map = {
  77. ms_service_pb2.MS_BOOL: np.bool,
  78. ms_service_pb2.MS_INT8: np.int8,
  79. ms_service_pb2.MS_UINT8: np.uint8,
  80. ms_service_pb2.MS_INT16: np.int16,
  81. ms_service_pb2.MS_UINT16: np.uint16,
  82. ms_service_pb2.MS_INT32: np.int32,
  83. ms_service_pb2.MS_UINT32: np.uint32,
  84. ms_service_pb2.MS_INT64: np.int64,
  85. ms_service_pb2.MS_UINT64: np.uint64,
  86. ms_service_pb2.MS_FLOAT16: np.float16,
  87. ms_service_pb2.MS_FLOAT32: np.float32,
  88. ms_service_pb2.MS_FLOAT64: np.float64,
  89. }
  90. if tensor.dtype == ms_service_pb2.MS_STRING or tensor.dtype == ms_service_pb2.MS_BYTES:
  91. result = []
  92. for item in tensor.bytes_val:
  93. if tensor.dtype == ms_service_pb2.MS_STRING:
  94. result.append(bytes.decode(item))
  95. else:
  96. result.append(item)
  97. if len(result) == 1:
  98. return result[0]
  99. return result
  100. result = np.frombuffer(tensor.data, dtype_map[tensor.dtype]).reshape(tensor.shape.dims)
  101. return result
  102. def _check_str(arg_name, str_val):
  103. """Check whether the input parameters are reasonable str input"""
  104. if not isinstance(str_val, str):
  105. raise RuntimeError(f"Parameter '{arg_name}' should be str, but actually {type(str_val)}")
  106. if not str_val:
  107. raise RuntimeError(f"Parameter '{arg_name}' should not be empty str")
  108. def _check_int(arg_name, int_val, minimum=None, maximum=None):
  109. """Check whether the input parameters are reasonable int input"""
  110. if not isinstance(int_val, int):
  111. raise RuntimeError(f"Parameter '{arg_name}' should be int, but actually {type(int_val)}")
  112. if minimum is not None and int_val < minimum:
  113. if maximum is not None:
  114. raise RuntimeError(f"Parameter '{arg_name}' should be in range [{minimum},{maximum}]")
  115. raise RuntimeError(f"Parameter '{arg_name}' should be >= {minimum}")
  116. if maximum is not None and int_val > maximum:
  117. if minimum is not None:
  118. raise RuntimeError(f"Parameter '{arg_name}' should be in range [{minimum},{maximum}]")
  119. raise RuntimeError(f"Parameter '{arg_name}' should be <= {maximum}")
  120. class Client:
  121. """
  122. The Client encapsulates the serving gRPC API, which can be used to create requests,
  123. access serving, and parse results.
  124. Args:
  125. ip (str): Serving ip.
  126. port (int): Serving port.
  127. servable_name (str): The name of servable supplied by Serving.
  128. method_name (str): The name of method supplied by servable.
  129. version_number (int): The version number of servable, default 0,
  130. which means the maximum version number in all running versions.
  131. Raises:
  132. RuntimeError: The type or value of the parameters is invalid, or other errors happened.
  133. Examples:
  134. >>> from mindspore_serving.client import Client
  135. >>> import numpy as np
  136. >>> client = Client("localhost", 5500, "add", "add_cast")
  137. >>> instances = []
  138. >>> x1 = np.ones((2, 2), np.int32)
  139. >>> x2 = np.ones((2, 2), np.int32)
  140. >>> instances.append({"x1": x1, "x2": x2})
  141. >>> result = client.infer(instances)
  142. >>> print(result)
  143. """
  144. def __init__(self, ip, port, servable_name, method_name, version_number=0):
  145. _check_str("ip", ip)
  146. _check_int("port", port, 1, 65535)
  147. _check_str("servable_name", servable_name)
  148. _check_str("method_name", method_name)
  149. _check_int("version_number", version_number, 0)
  150. self.ip = ip
  151. self.port = port
  152. self.servable_name = servable_name
  153. self.method_name = method_name
  154. self.version_number = version_number
  155. channel_str = str(ip) + ":" + str(port)
  156. msg_bytes_size = 512 * 1024 * 1024 # 512MB
  157. channel = grpc.insecure_channel(channel_str,
  158. options=[
  159. ('grpc.max_send_message_length', msg_bytes_size),
  160. ('grpc.max_receive_message_length', msg_bytes_size),
  161. ])
  162. self.stub = ms_service_pb2_grpc.MSServiceStub(channel)
  163. def infer(self, instances):
  164. """
  165. Used to create requests, access serving service, and parse and return results.
  166. Args:
  167. instances (map, tuple of map): Instance or tuple of instances, every instance item is the inputs map.
  168. The map key is the input name, and the value is the input value, the type of value can be python int,
  169. float, bool, str, bytes, numpy number, or numpy array object.
  170. Raises:
  171. RuntimeError: The type or value of the parameters is invalid, or other errors happened.
  172. Examples:
  173. >>> from mindspore_serving.client import Client
  174. >>> import numpy as np
  175. >>> client = Client("localhost", 5500, "add", "add_cast")
  176. >>> instances = []
  177. >>> x1 = np.ones((2, 2), np.int32)
  178. >>> x2 = np.ones((2, 2), np.int32)
  179. >>> instances.append({"x1": x1, "x2": x2})
  180. >>> result = client.infer(instances)
  181. >>> print(result)
  182. """
  183. request = self._create_request(instances)
  184. try:
  185. result = self.stub.Predict(request)
  186. return self._paser_result(result)
  187. except grpc.RpcError as e:
  188. print(e.details())
  189. status_code = e.code()
  190. print(status_code.name)
  191. print(status_code.value)
  192. return {"error": "Grpc Error, " + str(status_code.value)}
  193. def infer_async(self, instances):
  194. """
  195. Used to create requests, async access serving.
  196. Args:
  197. instances (map, tuple of map): Instance or tuple of instances, every instance item is the inputs map.
  198. The map key is the input name, and the value is the input value.
  199. Raises:
  200. RuntimeError: The type or value of the parameters is invalid, or other errors happened.
  201. Examples:
  202. >>> from mindspore_serving.client import Client
  203. >>> import numpy as np
  204. >>> client = Client("localhost", 5500, "add", "add_cast")
  205. >>> instances = []
  206. >>> x1 = np.ones((2, 2), np.int32)
  207. >>> x2 = np.ones((2, 2), np.int32)
  208. >>> instances.append({"x1": x1, "x2": x2})
  209. >>> result_future = client.infer_async(instances)
  210. >>> result = result_future.result()
  211. >>> print(result)
  212. """
  213. request = self._create_request(instances)
  214. try:
  215. result_future = self.stub.Predict.future(request)
  216. return ClientGrpcAsyncResult(result_future)
  217. except grpc.RpcError as e:
  218. print(e.details())
  219. status_code = e.code()
  220. print(status_code.name)
  221. print(status_code.value)
  222. return ClientGrpcAsyncError({"error": "Grpc Error, " + str(status_code.value)})
  223. def _create_request(self, instances):
  224. """Used to create request spec."""
  225. if not isinstance(instances, (tuple, list)):
  226. instances = (instances,)
  227. request = ms_service_pb2.PredictRequest()
  228. request.servable_spec.name = self.servable_name
  229. request.servable_spec.method_name = self.method_name
  230. request.servable_spec.version_number = self.version_number
  231. for item in instances:
  232. if isinstance(item, dict):
  233. request.instances.append(self._create_instance(**item))
  234. else:
  235. raise RuntimeError("instance should be a map")
  236. return request
  237. @staticmethod
  238. def _create_instance(**kwargs):
  239. """Used to create gRPC instance."""
  240. instance = ms_service_pb2.Instance()
  241. for k, w in kwargs.items():
  242. tensor = instance.items[k]
  243. if isinstance(w, (np.ndarray, np.number)):
  244. _create_tensor(w, tensor)
  245. elif isinstance(w, str):
  246. _create_str_tensor(w, tensor)
  247. elif isinstance(w, (bool, int, float)):
  248. _create_scalar_tensor(w, tensor)
  249. elif isinstance(w, bytes):
  250. _create_bytes_tensor(w, tensor)
  251. else:
  252. raise RuntimeError("Not support value type " + str(type(w)))
  253. return instance
  254. @staticmethod
  255. def _paser_result(result):
  256. """Used to parse result."""
  257. error_msg_len = len(result.error_msg)
  258. if error_msg_len == 1:
  259. return {"error": bytes.decode(result.error_msg[0].error_msg)}
  260. ret_val = []
  261. instance_len = len(result.instances)
  262. if error_msg_len not in (0, instance_len):
  263. raise RuntimeError(f"error msg result size {error_msg_len} not be 0, 1 or "
  264. f"length of instances {instance_len}")
  265. for i in range(instance_len):
  266. instance = result.instances[i]
  267. if error_msg_len == 0 or result.error_msg[i].error_code == 0:
  268. instance_map = {}
  269. for k, w in instance.items.items():
  270. instance_map[k] = _create_numpy_from_tensor(w)
  271. ret_val.append(instance_map)
  272. else:
  273. ret_val.append({"error": bytes.decode(result.error_msg[i].error_msg)})
  274. return ret_val
  275. class ClientGrpcAsyncResult:
  276. """
  277. When Client.infer_async invoke successfully, a ClientGrpcAsyncResult object is returned.
  278. Examples:
  279. >>> from mindspore_serving.client import Client
  280. >>> import numpy as np
  281. >>> client = Client("localhost", 5500, "add", "add_cast")
  282. >>> instances = []
  283. >>> x1 = np.ones((2, 2), np.int32)
  284. >>> x2 = np.ones((2, 2), np.int32)
  285. >>> instances.append({"x1": x1, "x2": x2})
  286. >>> result_future = client.infer_async(instances)
  287. >>> result = result_future.result()
  288. >>> print(result)
  289. """
  290. def __init__(self, result_future):
  291. self.result_future = result_future
  292. def result(self):
  293. """Wait and get result of inference result, the gRPC message will be parse to tuple of instances result.
  294. Every instance result is dict, and value could be numpy array/number, str or bytes according gRPC Tensor
  295. data type.
  296. """
  297. result = self.result_future.result()
  298. # pylint: disable=protected-access
  299. result = Client._paser_result(result)
  300. return result
  301. class ClientGrpcAsyncError:
  302. """When gRPC failed happened when calling Client.infer_async, a ClientGrpcAsyncError object is returned.
  303. """
  304. def __init__(self, result_error):
  305. self.result_error = result_error
  306. def result(self):
  307. """Get gRPC error message.
  308. """
  309. return self.result_error

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.