You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

client.py 8.8 kB

Lines 1–228
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import grpc
  16. import numpy as np
  17. import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2
  18. import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc
  19. def _create_tensor(data, tensor=None):
  20. if tensor is None:
  21. tensor = ms_service_pb2.Tensor()
  22. tensor.shape.dims.extend(data.shape)
  23. dtype_map = {
  24. np.bool: ms_service_pb2.MS_BOOL,
  25. np.int8: ms_service_pb2.MS_INT8,
  26. np.uint8: ms_service_pb2.MS_UINT8,
  27. np.int16: ms_service_pb2.MS_INT16,
  28. np.uint16: ms_service_pb2.MS_UINT16,
  29. np.int32: ms_service_pb2.MS_INT32,
  30. np.uint32: ms_service_pb2.MS_UINT32,
  31. np.int64: ms_service_pb2.MS_INT64,
  32. np.uint64: ms_service_pb2.MS_UINT64,
  33. np.float16: ms_service_pb2.MS_FLOAT16,
  34. np.float32: ms_service_pb2.MS_FLOAT32,
  35. np.float64: ms_service_pb2.MS_FLOAT64,
  36. }
  37. for k, v in dtype_map.items():
  38. if k == data.dtype:
  39. tensor.dtype = v
  40. break
  41. if tensor.dtype == ms_service_pb2.MS_UNKNOWN:
  42. raise RuntimeError("Unknown data type " + str(data.dtype))
  43. tensor.data = data.tobytes()
  44. return tensor
  45. def _create_scalar_tensor(vals, tensor=None):
  46. if not isinstance(vals, (tuple, list)):
  47. vals = (vals,)
  48. return _create_tensor(np.array(vals), tensor)
  49. def _create_bytes_tensor(bytes_vals, tensor=None):
  50. if tensor is None:
  51. tensor = ms_service_pb2.Tensor()
  52. if not isinstance(bytes_vals, (tuple, list)):
  53. bytes_vals = (bytes_vals,)
  54. tensor.shape.dims.extend([len(bytes_vals)])
  55. tensor.dtype = ms_service_pb2.MS_BYTES
  56. for item in bytes_vals:
  57. tensor.bytes_val.append(item)
  58. return tensor
  59. def _create_str_tensor(str_vals, tensor=None):
  60. if tensor is None:
  61. tensor = ms_service_pb2.Tensor()
  62. if not isinstance(str_vals, (tuple, list)):
  63. str_vals = (str_vals,)
  64. tensor.shape.dims.extend([len(str_vals)])
  65. tensor.dtype = ms_service_pb2.MS_STRING
  66. for item in str_vals:
  67. tensor.bytes_val.append(bytes(item, encoding="utf8"))
  68. return tensor
  69. def _create_numpy_from_tensor(tensor):
  70. dtype_map = {
  71. ms_service_pb2.MS_BOOL: np.bool,
  72. ms_service_pb2.MS_INT8: np.int8,
  73. ms_service_pb2.MS_UINT8: np.uint8,
  74. ms_service_pb2.MS_INT16: ms_service_pb2.MS_INT16,
  75. ms_service_pb2.MS_UINT16: np.uint16,
  76. ms_service_pb2.MS_INT32: np.int32,
  77. ms_service_pb2.MS_UINT32: np.uint32,
  78. ms_service_pb2.MS_INT64: np.int64,
  79. ms_service_pb2.MS_UINT64: np.uint64,
  80. ms_service_pb2.MS_FLOAT16: np.float16,
  81. ms_service_pb2.MS_FLOAT32: np.float32,
  82. ms_service_pb2.MS_FLOAT64: np.float64,
  83. }
  84. if tensor.dtype == ms_service_pb2.MS_STRING or tensor.dtype == ms_service_pb2.MS_BYTES:
  85. result = []
  86. for item in tensor.bytes_val:
  87. if tensor.dtype == ms_service_pb2.MS_STRING:
  88. result.append(bytes.decode(item))
  89. else:
  90. result.append(item)
  91. if len(result) == 1:
  92. return result[0]
  93. return result
  94. result = np.frombuffer(tensor.data, dtype_map[tensor.dtype]).reshape(tensor.shape.dims)
  95. if not tensor.shape.dims or (len(tensor.shape.dims) == 1 and tensor.shape.dims[0] == 1):
  96. result = result.reshape((1,))[0]
  97. return result
  98. def _check_str(arg_name, str_val):
  99. """Check whether the input parameters are reasonable str input"""
  100. if not isinstance(str_val, str):
  101. raise RuntimeError(f"Parameter '{arg_name}' should be str, but actually {type(str_val)}")
  102. if not str_val:
  103. raise RuntimeError(f"Parameter '{arg_name}' should not be empty str")
  104. def _check_int(arg_name, int_val, mininum=None, maximum=None):
  105. """Check whether the input parameters are reasonable int input"""
  106. if not isinstance(int_val, int):
  107. raise RuntimeError(f"Parameter '{arg_name}' should be int, but actually {type(int_val)}")
  108. if mininum is not None and int_val < mininum:
  109. if maximum is not None:
  110. raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
  111. raise RuntimeError(f"Parameter '{arg_name}' should be >= {mininum}")
  112. if maximum is not None and int_val > maximum:
  113. if mininum is not None:
  114. raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
  115. raise RuntimeError(f"Parameter '{arg_name}' should be <= {maximum}")
  116. class Client:
  117. def __init__(self, ip, port, servable_name, method_name, version_number=0):
  118. '''
  119. Create Client connect to serving
  120. :param ip: serving ip
  121. :param port: serving port
  122. :param servable_name: the name of servable supplied by serving
  123. :param method_name: method supplied by servable
  124. :param version_number: the version number of servable, default 0.
  125. 0 meaning the maximum version number in all running versions.
  126. '''
  127. _check_str("ip", ip)
  128. _check_int("port", port, 0, 65535)
  129. _check_str("servable_name", servable_name)
  130. _check_str("method_name", method_name)
  131. _check_int("version_number", version_number, 0)
  132. self.ip = ip
  133. self.port = port
  134. self.servable_name = servable_name
  135. self.method_name = method_name
  136. self.version_number = version_number
  137. channel_str = str(ip) + ":" + str(port)
  138. channel = grpc.insecure_channel(channel_str)
  139. self.stub = ms_service_pb2_grpc.MSServiceStub(channel)
  140. def _create_request(self):
  141. request = ms_service_pb2.PredictRequest()
  142. request.servable_spec.name = self.servable_name
  143. request.servable_spec.method_name = self.method_name
  144. request.servable_spec.version_number = self.version_number
  145. return request
  146. def _create_instance(self, **kwargs):
  147. instance = ms_service_pb2.Instance()
  148. for k, w in kwargs.items():
  149. tensor = instance.items[k]
  150. if isinstance(w, (np.ndarray, np.number)):
  151. _create_tensor(w, tensor)
  152. elif isinstance(w, str):
  153. _create_str_tensor(w, tensor)
  154. elif isinstance(w, (bool, int, float)):
  155. _create_scalar_tensor(w, tensor)
  156. elif isinstance(w, bytes):
  157. _create_bytes_tensor(w, tensor)
  158. else:
  159. raise RuntimeError("Not support value type " + str(type(w)))
  160. return instance
  161. def _paser_result(self, result):
  162. error_msg_len = len(result.error_msg)
  163. if error_msg_len == 1:
  164. return {"error": bytes.decode(result.error_msg[0].error_msg)}
  165. ret_val = []
  166. instance_len = len(result.instances)
  167. if error_msg_len not in (0, instance_len):
  168. raise RuntimeError(f"error msg result size {error_msg_len} not be 0, 1 or "
  169. f"length of instances {instance_len}")
  170. for i in range(instance_len):
  171. instance = result.instances[i]
  172. if error_msg_len == 0 or result.error_msg[i].error_code == 0:
  173. instance_map = {}
  174. for k, w in instance.items.items():
  175. instance_map[k] = _create_numpy_from_tensor(w)
  176. ret_val.append(instance_map)
  177. else:
  178. ret_val.append({"error": bytes.decode(result.error_msg[i].error_msg)})
  179. return ret_val
  180. def infer(self, instances):
  181. if not isinstance(instances, (tuple, list)):
  182. instances = (instances,)
  183. request = self._create_request()
  184. for item in instances:
  185. if isinstance(item, dict):
  186. request.instances.append(self._create_instance(**item))
  187. else:
  188. raise RuntimeError("instance should be a map")
  189. try:
  190. result = self.stub.Predict(request)
  191. return self._paser_result(result)
  192. except grpc.RpcError as e:
  193. print(e.details())
  194. status_code = e.code()
  195. print(status_code.name)
  196. print(status_code.value)
  197. return {"error": status_code.value}

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.