You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

client.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import grpc
  16. import numpy as np
  17. import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2
  18. import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc
  def _create_tensor(data, tensor=None):
      """Fill a ``ms_service_pb2.Tensor`` from a numpy array or numpy scalar.

      :param data: numpy array or numpy scalar; its ``shape``, ``dtype`` and raw
          bytes (``tobytes()``) are copied into the message.
      :param tensor: optional Tensor message to fill in place; a new one is
          created when None.
      :return: the filled Tensor message.
      :raises RuntimeError: if ``data.dtype`` has no MindSpore Serving counterpart.
      """
      if tensor is None:
          tensor = ms_service_pb2.Tensor()
      tensor.shape.dims.extend(data.shape)
      # ``np.bool`` was only a deprecated alias of the builtin ``bool`` and was
      # removed in NumPy 1.24, so referencing it raised AttributeError on every
      # call. ``np.bool_`` is numpy's boolean scalar type and compares equal to
      # ``np.dtype(bool)`` exactly like the old alias did.
      dtype_map = {
          np.bool_: ms_service_pb2.MS_BOOL,
          np.int8: ms_service_pb2.MS_INT8,
          np.uint8: ms_service_pb2.MS_UINT8,
          np.int16: ms_service_pb2.MS_INT16,
          np.uint16: ms_service_pb2.MS_UINT16,
          np.int32: ms_service_pb2.MS_INT32,
          np.uint32: ms_service_pb2.MS_UINT32,
          np.int64: ms_service_pb2.MS_INT64,
          np.uint64: ms_service_pb2.MS_UINT64,
          np.float16: ms_service_pb2.MS_FLOAT16,
          np.float32: ms_service_pb2.MS_FLOAT32,
          np.float64: ms_service_pb2.MS_FLOAT64,
      }
      for np_type, ms_type in dtype_map.items():
          # A numpy scalar type compares equal to the matching ``np.dtype``.
          if np_type == data.dtype:
              tensor.dtype = ms_type
              break
      # NOTE(review): relies on MS_UNKNOWN being the enum's default value, so a
      # tensor still holds it when no mapping matched; confirm against the proto.
      if tensor.dtype == ms_service_pb2.MS_UNKNOWN:
          raise RuntimeError("Unknown data type", data.dtype)
      tensor.data = data.tobytes()
      return tensor
  def _create_scalar_tensor(vals, tensor=None):
      """Wrap a Python scalar, or a list/tuple of scalars, as a 1-D numeric tensor.

      The values are passed through ``np.array`` so numpy picks the element
      dtype, then ``_create_tensor`` fills ``tensor`` (or a new Tensor when None).
      """
      values = vals if isinstance(vals, (tuple, list)) else (vals,)
      as_array = np.array(values)
      return _create_tensor(as_array, tensor)
  def _create_bytes_tensor(bytes_vals, tensor=None):
      """Build a 1-D ``MS_BYTES`` tensor holding one element per bytes value.

      :param bytes_vals: a single value or a list/tuple of values; each one is
          appended to ``bytes_val``.
      :param tensor: Tensor message to fill in place; a fresh one is created
          when None.
      :return: the filled Tensor message.
      """
      target = ms_service_pb2.Tensor() if tensor is None else tensor
      items = bytes_vals if isinstance(bytes_vals, (tuple, list)) else (bytes_vals,)
      target.shape.dims.extend([len(items)])
      target.dtype = ms_service_pb2.MS_BYTES
      for value in items:
          target.bytes_val.append(value)
      return target
  def _create_str_tensor(str_vals, tensor=None):
      """Build a 1-D ``MS_STRING`` tensor from one or more Python strings.

      Each string is UTF-8 encoded and stored as one ``bytes_val`` element.

      :param str_vals: a single ``str`` or a list/tuple of them.
      :param tensor: Tensor message to fill in place; a fresh one is created
          when None.
      :return: the filled Tensor message.
      """
      target = ms_service_pb2.Tensor() if tensor is None else tensor
      texts = str_vals if isinstance(str_vals, (tuple, list)) else (str_vals,)
      target.shape.dims.extend([len(texts)])
      target.dtype = ms_service_pb2.MS_STRING
      for text in texts:
          encoded = bytes(text, encoding="utf8")
          target.bytes_val.append(encoded)
      return target
  def _create_numpy_from_tensor(tensor):
      """Convert a ``ms_service_pb2.Tensor`` from a reply back into Python/numpy values.

      * ``MS_STRING`` tensors become a list of ``str`` (decoded from ``bytes_val``),
        and ``MS_BYTES`` tensors become a list of ``bytes``. When the tensor holds
        exactly one element, that element is returned instead of a list.
      * Numeric tensors become a numpy array built from ``tensor.data`` and
        reshaped to ``tensor.shape.dims``. A tensor whose dims are empty or
        exactly ``[1]`` is returned as a single numpy scalar.

      :raises RuntimeError: if ``tensor.dtype`` is not a supported type.
      """
      # Fixes: MS_INT16 previously mapped to the enum value itself instead of
      # ``np.int16``, so int16 replies could not be decoded; ``np.bool`` was
      # removed in NumPy 1.24 and is replaced by ``np.bool_``.
      dtype_map = {
          ms_service_pb2.MS_BOOL: np.bool_,
          ms_service_pb2.MS_INT8: np.int8,
          ms_service_pb2.MS_UINT8: np.uint8,
          ms_service_pb2.MS_INT16: np.int16,
          ms_service_pb2.MS_UINT16: np.uint16,
          ms_service_pb2.MS_INT32: np.int32,
          ms_service_pb2.MS_UINT32: np.uint32,
          ms_service_pb2.MS_INT64: np.int64,
          ms_service_pb2.MS_UINT64: np.uint64,
          ms_service_pb2.MS_FLOAT16: np.float16,
          ms_service_pb2.MS_FLOAT32: np.float32,
          ms_service_pb2.MS_FLOAT64: np.float64,
      }
      if tensor.dtype in (ms_service_pb2.MS_STRING, ms_service_pb2.MS_BYTES):
          if tensor.dtype == ms_service_pb2.MS_STRING:
              result = [bytes.decode(item) for item in tensor.bytes_val]
          else:
              result = list(tensor.bytes_val)
          # A single element is unwrapped for convenience.
          if len(result) == 1:
              return result[0]
          return result
      if tensor.dtype not in dtype_map:
          # Same exception type as _create_tensor uses for unsupported dtypes.
          raise RuntimeError("Unknown data type", tensor.dtype)
      dims = tensor.shape.dims
      result = np.frombuffer(tensor.data, dtype_map[tensor.dtype]).reshape(dims)
      # Scalars (no dims) and single-element 1-D tensors are unwrapped to a scalar.
      if not dims or (len(dims) == 1 and dims[0] == 1):
          result = result.reshape((1,))[0]
      return result
  class Client:
      """gRPC client that sends predict requests to one method of a MindSpore Serving servable."""

      def __init__(self, ip, port, servable_name, method_name, version_number=None):
          '''
          Create a Client connected to serving.
          :param ip: serving ip
          :param port: serving port
          :param servable_name: the name of servable supplied by serving
          :param method_name: method supplied by servable
          :param version_number: the version number of servable, default None.
          None meaning the maximum version number in all running versions.
          '''
          self.ip = ip
          self.port = port
          self.servable_name = servable_name
          self.method_name = method_name
          self.version_number = version_number
          channel_str = str(ip) + ":" + str(port)
          # Keep the channel so its resources can be released with close();
          # previously the only reference was dropped after building the stub.
          self._channel = grpc.insecure_channel(channel_str)
          self.stub = ms_service_pb2_grpc.MSServiceStub(self._channel)

      def close(self):
          """Close the underlying gRPC channel; the client must not be used afterwards."""
          self._channel.close()

      def _create_request(self):
          """Build an empty PredictRequest addressed to this client's servable and method."""
          request = ms_service_pb2.PredictRequest()
          request.servable_spec.name = self.servable_name
          request.servable_spec.method_name = self.method_name
          # Leaving version_number unset lets the server choose the version.
          if self.version_number is not None:
              request.servable_spec.version_number = self.version_number
          return request

      def _create_instance(self, **kwargs):
          """Build one request Instance whose items map each keyword name to a tensor.

          Supported values: numpy arrays, numpy numeric or bool scalars, ``str``,
          Python ``bool``/``int``/``float``, and ``bytes``.

          :raises RuntimeError: for any other value type.
          """
          instance = ms_service_pb2.Instance()
          for k, w in kwargs.items():
              tensor = instance.items[k]
              # np.bool_ is not a subclass of np.number, but _create_tensor
              # handles bool arrays, so numpy bool scalars are accepted too.
              if isinstance(w, (np.ndarray, np.number, np.bool_)):
                  _create_tensor(w, tensor)
              elif isinstance(w, str):
                  _create_str_tensor(w, tensor)
              elif isinstance(w, (bool, int, float)):
                  _create_scalar_tensor(w, tensor)
              elif isinstance(w, bytes):
                  _create_bytes_tensor(w, tensor)
              else:
                  raise RuntimeError("Not support value type " + str(type(w)))
          return instance

      def _paser_result(self, result):
          """Convert a predict reply into Python values.

          :return: ``{"error": msg}`` when the reply carries exactly one error
              message. Otherwise, a list with one entry per instance: a dict
              mapping each output name to its value, or ``{"error": msg}`` for an
              instance whose error_code is non-zero.
          :raises RuntimeError: if the number of error messages is not 0, 1, or
              the number of instances.
          """
          error_msgs = result.error_msg
          error_msg_len = len(error_msgs)
          if error_msg_len == 1:
              return {"error": bytes.decode(error_msgs[0].error_msg)}
          instance_len = len(result.instances)
          if error_msg_len not in (0, instance_len):
              # str() is required here: concatenating the int directly raised a
              # TypeError that masked this message.
              raise RuntimeError(
                  "error msg result size " + str(error_msg_len)
                  + " not be 0,1 or length of instances " + str(instance_len))
          ret_val = []
          for i, instance in enumerate(result.instances):
              if error_msg_len == 0 or error_msgs[i].error_code == 0:
                  ret_val.append({k: _create_numpy_from_tensor(w) for k, w in instance.items.items()})
              else:
                  ret_val.append({"error": bytes.decode(error_msgs[i].error_msg)})
          return ret_val

      def infer(self, instances):
          """Send one or more instances to the servable and return the parsed reply.

          :param instances: a dict mapping input names to values, or a list/tuple
              of such dicts.
          :return: the value produced by ``_paser_result``; if the gRPC call fails,
              ``{"error": status_code.value}`` for the call's status code.
          :raises RuntimeError: if an instance is not a dict, holds an unsupported
              value type, or the reply's error list has an unexpected length.
          """
          if not isinstance(instances, (tuple, list)):
              instances = (instances,)
          request = self._create_request()
          for item in instances:
              if not isinstance(item, dict):
                  raise RuntimeError("instance should be a map")
              request.instances.append(self._create_instance(**item))
          # Only the RPC itself can raise grpc.RpcError; parsing stays outside the try.
          try:
              result = self.stub.Predict(request)
          except grpc.RpcError as e:
              print(e.details())
              status_code = e.code()
              print(status_code.name)
              print(status_code.value)
              return {"error": status_code.value}
          return self._paser_result(result)

A lightweight, high-performance serving module that helps MindSpore developers efficiently deploy online inference services in production environments.