You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

client.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """MindSpore Serving Client"""
  16. import grpc
  17. import numpy as np
  18. import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2
  19. import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc
  def _create_tensor(data, tensor=None):
      """Fill a protobuf ``Tensor`` from NumPy data.

      Args:
          data: A ``np.ndarray`` or NumPy scalar; only ``shape``, ``dtype`` and ``tobytes()`` are used.
          tensor: The ``ms_service_pb2.Tensor`` to fill. A new one is created when None.

      Returns:
          The filled tensor (the same object as ``tensor`` when one was given).

      Raises:
          RuntimeError: ``data.dtype`` has no serving counterpart in the table below.
      """
      if tensor is None:
          tensor = ms_service_pb2.Tensor()
      tensor.shape.dims.extend(data.shape)
      # np.bool_ is the NumPy boolean type. The alias np.bool was removed in NumPy 1.24
      # (referencing it raised AttributeError until it returned in 2.0), so np.bool_ is
      # the spelling that works across versions.
      dtype_map = {
          np.bool_: ms_service_pb2.MS_BOOL,
          np.int8: ms_service_pb2.MS_INT8,
          np.uint8: ms_service_pb2.MS_UINT8,
          np.int16: ms_service_pb2.MS_INT16,
          np.uint16: ms_service_pb2.MS_UINT16,
          np.int32: ms_service_pb2.MS_INT32,
          np.uint32: ms_service_pb2.MS_UINT32,
          np.int64: ms_service_pb2.MS_INT64,
          np.uint64: ms_service_pb2.MS_UINT64,
          np.float16: ms_service_pb2.MS_FLOAT16,
          np.float32: ms_service_pb2.MS_FLOAT32,
          np.float64: ms_service_pb2.MS_FLOAT64,
      }
      # Compare with == so NumPy resolves each scalar type against data.dtype by value.
      for np_type, ms_type in dtype_map.items():
          if np_type == data.dtype:
              tensor.dtype = ms_type
              break
      else:
          # Raise on "no match" directly instead of relying on the proto's default dtype.
          raise RuntimeError("Unknown data type " + str(data.dtype))
      # tobytes() serializes in C order regardless of the array's memory layout.
      tensor.data = data.tobytes()
      return tensor
  def _create_scalar_tensor(vals, tensor=None):
      """Fill a protobuf ``Tensor`` from Python scalar value(s).

      Args:
          vals: A single bool/int/float, or a tuple/list of them.
          tensor: The ``ms_service_pb2.Tensor`` to fill; created by ``_create_tensor`` when None.

      Returns:
          The filled tensor. The NumPy dtype inferred by ``np.array`` selects the serving dtype.
      """
      values = vals if isinstance(vals, (tuple, list)) else (vals,)
      return _create_tensor(np.array(values), tensor)
  def _create_bytes_tensor(bytes_vals, tensor=None):
      """Fill a protobuf ``Tensor`` of dtype ``MS_BYTES``.

      Args:
          bytes_vals: A single ``bytes`` object, or a tuple/list of them.
          tensor: The ``ms_service_pb2.Tensor`` to fill. A new one is created when None.

      Returns:
          The filled tensor: a 1-D shape whose length is the number of items, and one
          ``bytes_val`` entry per item.
      """
      if tensor is None:
          tensor = ms_service_pb2.Tensor()
      items = bytes_vals if isinstance(bytes_vals, (tuple, list)) else (bytes_vals,)
      tensor.shape.dims.extend([len(items)])
      tensor.dtype = ms_service_pb2.MS_BYTES
      tensor.bytes_val.extend(items)
      return tensor
  def _create_str_tensor(str_vals, tensor=None):
      """Fill a protobuf ``Tensor`` of dtype ``MS_STRING``.

      Args:
          str_vals: A single ``str``, or a tuple/list of them.
          tensor: The ``ms_service_pb2.Tensor`` to fill. A new one is created when None.

      Returns:
          The filled tensor: a 1-D shape whose length is the number of strings, and each
          string stored UTF-8 encoded in ``bytes_val``.
      """
      if tensor is None:
          tensor = ms_service_pb2.Tensor()
      texts = str_vals if isinstance(str_vals, (tuple, list)) else (str_vals,)
      tensor.shape.dims.extend([len(texts)])
      tensor.dtype = ms_service_pb2.MS_STRING
      tensor.bytes_val.extend([bytes(text, encoding="utf8") for text in texts])
      return tensor
  def _create_numpy_from_tensor(tensor):
      """Convert a protobuf ``Tensor`` received from serving back into Python/NumPy values.

      Args:
          tensor: An ``ms_service_pb2.Tensor``.

      Returns:
          * ``MS_STRING``: a list of decoded ``str``, or the single ``str`` when there is one item.
          * ``MS_BYTES``: a list of ``bytes``, or the single ``bytes`` when there is one item.
          * numeric/bool dtypes: a NumPy array with the tensor's shape, or a NumPy scalar
            when the shape is empty or ``[1]``.

      Raises:
          RuntimeError: The tensor's dtype is not one of the supported serving dtypes.
      """
      # Fixed: MS_INT16 previously mapped to the enum value itself instead of np.int16,
      # and np.bool (removed in NumPy 1.24) is replaced by np.bool_.
      dtype_map = {
          ms_service_pb2.MS_BOOL: np.bool_,
          ms_service_pb2.MS_INT8: np.int8,
          ms_service_pb2.MS_UINT8: np.uint8,
          ms_service_pb2.MS_INT16: np.int16,
          ms_service_pb2.MS_UINT16: np.uint16,
          ms_service_pb2.MS_INT32: np.int32,
          ms_service_pb2.MS_UINT32: np.uint32,
          ms_service_pb2.MS_INT64: np.int64,
          ms_service_pb2.MS_UINT64: np.uint64,
          ms_service_pb2.MS_FLOAT16: np.float16,
          ms_service_pb2.MS_FLOAT32: np.float32,
          ms_service_pb2.MS_FLOAT64: np.float64,
      }
      if tensor.dtype in (ms_service_pb2.MS_STRING, ms_service_pb2.MS_BYTES):
          if tensor.dtype == ms_service_pb2.MS_STRING:
              result = [bytes.decode(item) for item in tensor.bytes_val]
          else:
              result = list(tensor.bytes_val)
          # A single string/bytes value is unwrapped rather than returned as a 1-item list.
          return result[0] if len(result) == 1 else result
      np_dtype = dtype_map.get(tensor.dtype)
      if np_dtype is None:
          raise RuntimeError("Unknown data type " + str(tensor.dtype))
      dims = tensor.shape.dims
      result = np.frombuffer(tensor.data, np_dtype).reshape(dims)
      # Scalars (empty shape) and single-element 1-D tensors are returned as a NumPy scalar.
      if not dims or (len(dims) == 1 and dims[0] == 1):
          result = result.reshape((1,))[0]
      return result
  104. def _check_str(arg_name, str_val):
  105. """Check whether the input parameters are reasonable str input"""
  106. if not isinstance(str_val, str):
  107. raise RuntimeError(f"Parameter '{arg_name}' should be str, but actually {type(str_val)}")
  108. if not str_val:
  109. raise RuntimeError(f"Parameter '{arg_name}' should not be empty str")
  110. def _check_int(arg_name, int_val, mininum=None, maximum=None):
  111. """Check whether the input parameters are reasonable int input"""
  112. if not isinstance(int_val, int):
  113. raise RuntimeError(f"Parameter '{arg_name}' should be int, but actually {type(int_val)}")
  114. if mininum is not None and int_val < mininum:
  115. if maximum is not None:
  116. raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
  117. raise RuntimeError(f"Parameter '{arg_name}' should be >= {mininum}")
  118. if maximum is not None and int_val > maximum:
  119. if mininum is not None:
  120. raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
  121. raise RuntimeError(f"Parameter '{arg_name}' should be <= {maximum}")
  class Client:
      """
      The Client encapsulates the serving gRPC API, which can be used to create requests,
      access serving, and parse results.

      The underlying gRPC channel is released by ``close()``; the client can also be used as a
      context manager (``with Client(...) as client:``), which closes it on exit.

      Args:
          ip (str): Serving ip.
          port (int): Serving port, value range [0, 65535].
          servable_name (str): The name of servable supplied by Serving.
          method_name (str): The name of method supplied by servable.
          version_number (int): The version number of servable, default 0,
              0 meaning the maximum version number in all running versions.
          max_msg_mb_size (int): The maximum acceptable gRPC message size in megabytes(MB), default 512,
              value range [1, 512].

      Raises:
          RuntimeError: The type or value of the parameters is invalid, or other error happened.

      Examples:
          >>> from mindspore_serving.client import Client
          >>> import numpy as np
          >>> client = Client("localhost", 5500, "add", "add_cast")
          >>> instances = []
          >>> x1 = np.ones((2, 2), np.int32)
          >>> x2 = np.ones((2, 2), np.int32)
          >>> instances.append({"x1": x1, "x2": x2})
          >>> result = client.infer(instances)
          >>> print(result)
          >>> client.close()
      """

      def __init__(self, ip, port, servable_name, method_name, version_number=0, max_msg_mb_size=512):
          _check_str("ip", ip)
          _check_int("port", port, 0, 65535)
          _check_str("servable_name", servable_name)
          _check_str("method_name", method_name)
          _check_int("version_number", version_number, 0)
          _check_int("max_msg_mb_size", max_msg_mb_size, 1, 512)
          self.ip = ip
          self.port = port
          self.servable_name = servable_name
          self.method_name = method_name
          self.version_number = version_number
          channel_str = f"{ip}:{port}"
          msg_bytes_size = max_msg_mb_size * 1024 * 1024
          # Keep a handle on the channel so close() can release it; previously it was only
          # reachable through the stub and could never be closed explicitly.
          self._channel = grpc.insecure_channel(channel_str,
                                                options=[
                                                    ('grpc.max_send_message_length', msg_bytes_size),
                                                    ('grpc.max_receive_message_length', msg_bytes_size),
                                                ])
          self.stub = ms_service_pb2_grpc.MSServiceStub(self._channel)

      def close(self):
          """Close the underlying gRPC channel. Do not call ``infer`` after closing."""
          self._channel.close()

      def __enter__(self):
          return self

      def __exit__(self, exc_type, exc_val, exc_tb):
          self.close()

      def infer(self, instances):
          """
          Used to create requests, access serving, and parse results.

          Args:
              instances (map, tuple of map): Instance or tuple of instance, every instance item is the inputs map.
                  The map key is the input name, and the value is the input value.

          Returns:
              A list with one entry per returned instance: either a map built from that instance's
              result tensors, or ``{"error": msg}`` for an instance the server marked as failed.
              When the server reports exactly one error message, a single ``{"error": msg}`` map is
              returned instead of a list. On a gRPC failure, ``{"error": "Grpc Error, <status value>"}``
              is returned and the failure details are logged.

          Raises:
              RuntimeError: The type or value of the parameters is invalid, or other error happened.
          """
          if not isinstance(instances, (tuple, list)):
              instances = (instances,)
          request = self._create_request()
          for item in instances:
              if not isinstance(item, dict):
                  raise RuntimeError("instance should be a map")
              request.instances.append(self._create_instance(**item))
          try:
              result = self.stub.Predict(request)
          except grpc.RpcError as e:
              # Local import keeps the module-level import list unchanged. Library code reports
              # through logging instead of printing to stdout.
              import logging
              status_code = e.code()
              logging.getLogger(__name__).error("Serving gRPC call failed: %s (code %s, %s)",
                                                e.details(), status_code.name, status_code.value)
              return {"error": "Grpc Error, " + str(status_code.value)}
          return self._paser_result(result)

      def _create_request(self):
          """Create a ``PredictRequest`` whose servable spec is taken from this client's settings."""
          request = ms_service_pb2.PredictRequest()
          request.servable_spec.name = self.servable_name
          request.servable_spec.method_name = self.method_name
          request.servable_spec.version_number = self.version_number
          return request

      def _create_instance(self, **kwargs):
          """Create an ``Instance`` with one tensor item per keyword argument.

          Supported values: NumPy arrays/scalars (including ``np.bool_``), ``str``,
          ``bool``/``int``/``float``, and ``bytes``.

          Raises:
              RuntimeError: A value has an unsupported type.
          """
          instance = ms_service_pb2.Instance()
          for k, w in kwargs.items():
              tensor = instance.items[k]
              # np.bool_ is not a subclass of np.number, so NumPy bool scalars are listed explicitly.
              if isinstance(w, (np.ndarray, np.number, np.bool_)):
                  _create_tensor(w, tensor)
              elif isinstance(w, str):
                  _create_str_tensor(w, tensor)
              elif isinstance(w, (bool, int, float)):
                  _create_scalar_tensor(w, tensor)
              elif isinstance(w, bytes):
                  _create_bytes_tensor(w, tensor)
              else:
                  raise RuntimeError("Not support value type " + str(type(w)))
          return instance

      def _paser_result(self, result):
          """Convert a ``PredictReply``-like result into the value returned by ``infer``.

          Raises:
              RuntimeError: The number of error messages is neither 0, 1, nor the number of instances.
          """
          error_msgs = result.error_msg
          error_msg_len = len(error_msgs)
          # A single error message describes a failure of the whole request.
          if error_msg_len == 1:
              return {"error": bytes.decode(error_msgs[0].error_msg)}
          instance_len = len(result.instances)
          if error_msg_len not in (0, instance_len):
              raise RuntimeError(f"error msg result size {error_msg_len} not be 0, 1 or "
                                 f"length of instances {instance_len}")
          ret_val = []
          for i, instance in enumerate(result.instances):
              # Otherwise error messages are per instance; error_code 0 marks success.
              if error_msg_len == 0 or error_msgs[i].error_code == 0:
                  ret_val.append({k: _create_numpy_from_tensor(w) for k, w in instance.items.items()})
              else:
                  ret_val.append({"error": bytes.decode(error_msgs[i].error_msg)})
          return ret_val

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.