You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

client.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """MindSpore Serving Client"""
  16. import grpc
  17. import numpy as np
  18. import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2
  19. import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc
  20. def _create_tensor(data, tensor=None):
  21. """Create tensor from numpy data"""
  22. if tensor is None:
  23. tensor = ms_service_pb2.Tensor()
  24. tensor.shape.dims.extend(data.shape)
  25. dtype_map = {
  26. np.bool: ms_service_pb2.MS_BOOL,
  27. np.int8: ms_service_pb2.MS_INT8,
  28. np.uint8: ms_service_pb2.MS_UINT8,
  29. np.int16: ms_service_pb2.MS_INT16,
  30. np.uint16: ms_service_pb2.MS_UINT16,
  31. np.int32: ms_service_pb2.MS_INT32,
  32. np.uint32: ms_service_pb2.MS_UINT32,
  33. np.int64: ms_service_pb2.MS_INT64,
  34. np.uint64: ms_service_pb2.MS_UINT64,
  35. np.float16: ms_service_pb2.MS_FLOAT16,
  36. np.float32: ms_service_pb2.MS_FLOAT32,
  37. np.float64: ms_service_pb2.MS_FLOAT64,
  38. }
  39. for k, v in dtype_map.items():
  40. if k == data.dtype:
  41. tensor.dtype = v
  42. break
  43. if tensor.dtype == ms_service_pb2.MS_UNKNOWN:
  44. raise RuntimeError("Unknown data type " + str(data.dtype))
  45. tensor.data = data.tobytes()
  46. return tensor
  47. def _create_scalar_tensor(vals, tensor=None):
  48. """Create tensor from scalar data"""
  49. if not isinstance(vals, (tuple, list)):
  50. vals = (vals,)
  51. return _create_tensor(np.array(vals), tensor)
  52. def _create_bytes_tensor(bytes_vals, tensor=None):
  53. """Create tensor from bytes data"""
  54. if tensor is None:
  55. tensor = ms_service_pb2.Tensor()
  56. if not isinstance(bytes_vals, (tuple, list)):
  57. bytes_vals = (bytes_vals,)
  58. tensor.shape.dims.extend([len(bytes_vals)])
  59. tensor.dtype = ms_service_pb2.MS_BYTES
  60. for item in bytes_vals:
  61. tensor.bytes_val.append(item)
  62. return tensor
  63. def _create_str_tensor(str_vals, tensor=None):
  64. """Create tensor from str data"""
  65. if tensor is None:
  66. tensor = ms_service_pb2.Tensor()
  67. if not isinstance(str_vals, (tuple, list)):
  68. str_vals = (str_vals,)
  69. tensor.shape.dims.extend([len(str_vals)])
  70. tensor.dtype = ms_service_pb2.MS_STRING
  71. for item in str_vals:
  72. tensor.bytes_val.append(bytes(item, encoding="utf8"))
  73. return tensor
  74. def _create_numpy_from_tensor(tensor):
  75. """Create numpy from protobuf tensor"""
  76. dtype_map = {
  77. ms_service_pb2.MS_BOOL: np.bool,
  78. ms_service_pb2.MS_INT8: np.int8,
  79. ms_service_pb2.MS_UINT8: np.uint8,
  80. ms_service_pb2.MS_INT16: np.int16,
  81. ms_service_pb2.MS_UINT16: np.uint16,
  82. ms_service_pb2.MS_INT32: np.int32,
  83. ms_service_pb2.MS_UINT32: np.uint32,
  84. ms_service_pb2.MS_INT64: np.int64,
  85. ms_service_pb2.MS_UINT64: np.uint64,
  86. ms_service_pb2.MS_FLOAT16: np.float16,
  87. ms_service_pb2.MS_FLOAT32: np.float32,
  88. ms_service_pb2.MS_FLOAT64: np.float64,
  89. }
  90. if tensor.dtype == ms_service_pb2.MS_STRING or tensor.dtype == ms_service_pb2.MS_BYTES:
  91. result = []
  92. for item in tensor.bytes_val:
  93. if tensor.dtype == ms_service_pb2.MS_STRING:
  94. result.append(bytes.decode(item))
  95. else:
  96. result.append(item)
  97. if len(result) == 1:
  98. return result[0]
  99. return result
  100. result = np.frombuffer(tensor.data, dtype_map[tensor.dtype]).reshape(tensor.shape.dims)
  101. return result
  102. def _check_str(arg_name, str_val):
  103. """Check whether the input parameters are reasonable str input"""
  104. if not isinstance(str_val, str):
  105. raise RuntimeError(f"Parameter '{arg_name}' should be str, but actually {type(str_val)}")
  106. if not str_val:
  107. raise RuntimeError(f"Parameter '{arg_name}' should not be empty str")
  108. def _check_int(arg_name, int_val, mininum=None, maximum=None):
  109. """Check whether the input parameters are reasonable int input"""
  110. if not isinstance(int_val, int):
  111. raise RuntimeError(f"Parameter '{arg_name}' should be int, but actually {type(int_val)}")
  112. if mininum is not None and int_val < mininum:
  113. if maximum is not None:
  114. raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
  115. raise RuntimeError(f"Parameter '{arg_name}' should be >= {mininum}")
  116. if maximum is not None and int_val > maximum:
  117. if mininum is not None:
  118. raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
  119. raise RuntimeError(f"Parameter '{arg_name}' should be <= {maximum}")
  120. class Client:
  121. """
  122. The Client encapsulates the serving gRPC API, which can be used to create requests,
  123. access serving, and parse results.
  124. Args:
  125. ip (str): Serving ip.
  126. port (int): Serving port.
  127. servable_name (str): The name of servable supplied by Serving.
  128. method_name (str): The name of method supplied by servable.
  129. version_number (int): The version number of servable, default 0,
  130. which means the maximum version number in all running versions.
  131. Raises:
  132. RuntimeError: The type or value of the parameters is invalid, or other errors happened.
  133. Examples:
  134. >>> from mindspore_serving.client import Client
  135. >>> import numpy as np
  136. >>> client = Client("localhost", 5500, "add", "add_cast")
  137. >>> instances = []
  138. >>> x1 = np.ones((2, 2), np.int32)
  139. >>> x2 = np.ones((2, 2), np.int32)
  140. >>> instances.append({"x1": x1, "x2": x2})
  141. >>> result = client.infer(instances)
  142. >>> print(result)
  143. """
  144. def __init__(self, ip, port, servable_name, method_name, version_number=0):
  145. _check_str("ip", ip)
  146. _check_int("port", port, 1, 65535)
  147. _check_str("servable_name", servable_name)
  148. _check_str("method_name", method_name)
  149. _check_int("version_number", version_number, 0)
  150. self.ip = ip
  151. self.port = port
  152. self.servable_name = servable_name
  153. self.method_name = method_name
  154. self.version_number = version_number
  155. channel_str = str(ip) + ":" + str(port)
  156. msg_bytes_size = 512 * 1024 * 1024 # 512MB
  157. channel = grpc.insecure_channel(channel_str,
  158. options=[
  159. ('grpc.max_send_message_length', msg_bytes_size),
  160. ('grpc.max_receive_message_length', msg_bytes_size),
  161. ])
  162. self.stub = ms_service_pb2_grpc.MSServiceStub(channel)
  163. def infer(self, instances):
  164. """
  165. Used to create requests, access serving, and parse results.
  166. Args:
  167. instances (map, tuple of map): Instance or tuple of instances, every instance item is the inputs map.
  168. The map key is the input name, and the value is the input value.
  169. Raises:
  170. RuntimeError: The type or value of the parameters is invalid, or other errors happened.
  171. """
  172. if not isinstance(instances, (tuple, list)):
  173. instances = (instances,)
  174. request = self._create_request()
  175. for item in instances:
  176. if isinstance(item, dict):
  177. request.instances.append(self._create_instance(**item))
  178. else:
  179. raise RuntimeError("instance should be a map")
  180. try:
  181. result = self.stub.Predict(request)
  182. return self._paser_result(result)
  183. except grpc.RpcError as e:
  184. print(e.details())
  185. status_code = e.code()
  186. print(status_code.name)
  187. print(status_code.value)
  188. return {"error": "Grpc Error, " + str(status_code.value)}
  189. def _create_request(self):
  190. """Used to create request spec."""
  191. request = ms_service_pb2.PredictRequest()
  192. request.servable_spec.name = self.servable_name
  193. request.servable_spec.method_name = self.method_name
  194. request.servable_spec.version_number = self.version_number
  195. return request
  196. def _create_instance(self, **kwargs):
  197. """Used to create gRPC instance."""
  198. instance = ms_service_pb2.Instance()
  199. for k, w in kwargs.items():
  200. tensor = instance.items[k]
  201. if isinstance(w, (np.ndarray, np.number)):
  202. _create_tensor(w, tensor)
  203. elif isinstance(w, str):
  204. _create_str_tensor(w, tensor)
  205. elif isinstance(w, (bool, int, float)):
  206. _create_scalar_tensor(w, tensor)
  207. elif isinstance(w, bytes):
  208. _create_bytes_tensor(w, tensor)
  209. else:
  210. raise RuntimeError("Not support value type " + str(type(w)))
  211. return instance
  212. def _paser_result(self, result):
  213. """Used to parse result."""
  214. error_msg_len = len(result.error_msg)
  215. if error_msg_len == 1:
  216. return {"error": bytes.decode(result.error_msg[0].error_msg)}
  217. ret_val = []
  218. instance_len = len(result.instances)
  219. if error_msg_len not in (0, instance_len):
  220. raise RuntimeError(f"error msg result size {error_msg_len} not be 0, 1 or "
  221. f"length of instances {instance_len}")
  222. for i in range(instance_len):
  223. instance = result.instances[i]
  224. if error_msg_len == 0 or result.error_msg[i].error_code == 0:
  225. instance_map = {}
  226. for k, w in instance.items.items():
  227. instance_map[k] = _create_numpy_from_tensor(w)
  228. ret_val.append(instance_map)
  229. else:
  230. ret_val.append({"error": bytes.decode(result.error_msg[i].error_msg)})
  231. return ret_val

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.