|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """MindSpore Serving Client"""
-
import logging

import grpc
import numpy as np
import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2
import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc
-
-
def _create_tensor(data, tensor=None):
    """Create a protobuf tensor from numpy data.

    Args:
        data (numpy.ndarray, numpy.number): Data whose shape, dtype and raw bytes are copied.
        tensor (ms_service_pb2.Tensor): Tensor to fill in place; a new one is created when None.

    Returns:
        ms_service_pb2.Tensor, the filled tensor.

    Raises:
        RuntimeError: The dtype of `data` is not supported.
    """
    # np.bool_ (not the removed np.bool alias) keeps this working on NumPy >= 1.24.
    dtype_map = (
        (np.bool_, ms_service_pb2.MS_BOOL),
        (np.int8, ms_service_pb2.MS_INT8),
        (np.uint8, ms_service_pb2.MS_UINT8),
        (np.int16, ms_service_pb2.MS_INT16),
        (np.uint16, ms_service_pb2.MS_UINT16),
        (np.int32, ms_service_pb2.MS_INT32),
        (np.uint32, ms_service_pb2.MS_UINT32),
        (np.int64, ms_service_pb2.MS_INT64),
        (np.uint64, ms_service_pb2.MS_UINT64),
        (np.float16, ms_service_pb2.MS_FLOAT16),
        (np.float32, ms_service_pb2.MS_FLOAT32),
        (np.float64, ms_service_pb2.MS_FLOAT64),
    )
    # Compare by dtype equality (not by `data.dtype.type` identity) so platform
    # aliases such as np.longlong still match their same-sized entry.
    ms_type = next((ms for np_type, ms in dtype_map if data.dtype == np_type), None)
    if ms_type is None:
        raise RuntimeError("Unknown data type " + str(data.dtype))

    # Only touch the tensor once the dtype is known to be supported.
    if tensor is None:
        tensor = ms_service_pb2.Tensor()
    tensor.shape.dims.extend(data.shape)
    tensor.dtype = ms_type
    tensor.data = data.tobytes()
    return tensor
-
-
def _create_scalar_tensor(vals, tensor=None):
    """Create a protobuf tensor from a Python scalar or a tuple/list of scalars.

    A lone scalar is wrapped into a one-element sequence, so it is sent as a
    one-element 1-D tensor. Conversion is delegated to `_create_tensor` via np.array.
    """
    values = vals if isinstance(vals, (tuple, list)) else (vals,)
    array = np.array(values)
    return _create_tensor(array, tensor)
-
-
def _create_bytes_tensor(bytes_vals, tensor=None):
    """Create a protobuf MS_BYTES tensor from one bytes value or a tuple/list of them.

    The tensor shape is [number of items]; each item is stored in `bytes_val`.
    A tensor passed in is filled in place, otherwise a new one is created.
    """
    items = bytes_vals if isinstance(bytes_vals, (tuple, list)) else (bytes_vals,)
    if tensor is None:
        tensor = ms_service_pb2.Tensor()

    tensor.shape.dims.extend([len(items)])
    tensor.dtype = ms_service_pb2.MS_BYTES
    for value in items:
        tensor.bytes_val.append(value)
    return tensor
-
-
def _create_str_tensor(str_vals, tensor=None):
    """Create a protobuf MS_STRING tensor from one str or a tuple/list of str.

    Each string is UTF-8 encoded into `bytes_val`; the tensor shape is
    [number of strings]. A tensor passed in is filled in place.
    """
    texts = str_vals if isinstance(str_vals, (tuple, list)) else (str_vals,)
    if tensor is None:
        tensor = ms_service_pb2.Tensor()

    tensor.shape.dims.extend([len(texts)])
    tensor.dtype = ms_service_pb2.MS_STRING
    for text in texts:
        encoded = bytes(text, encoding="utf8")
        tensor.bytes_val.append(encoded)
    return tensor
-
-
def _create_numpy_from_tensor(tensor):
    """Create numpy (or str/bytes) data from a protobuf tensor.

    Args:
        tensor (ms_service_pb2.Tensor): Tensor received from serving.

    Returns:
        For MS_STRING / MS_BYTES tensors: the decoded str (or raw bytes) when the
        tensor holds exactly one item, otherwise a list of them.
        For numeric tensors: a numpy array reshaped to the tensor's dims. It is built
        with np.frombuffer over the immutable `tensor.data`, so it is read-only.

    Raises:
        RuntimeError: The dtype of the tensor is not supported.
    """
    if tensor.dtype in (ms_service_pb2.MS_STRING, ms_service_pb2.MS_BYTES):
        if tensor.dtype == ms_service_pb2.MS_STRING:
            result = [bytes.decode(item) for item in tensor.bytes_val]
        else:
            result = list(tensor.bytes_val)
        # A single item is unwrapped to keep the common scalar case simple.
        if len(result) == 1:
            return result[0]
        return result

    # np.bool_ (not the removed np.bool alias) keeps this working on NumPy >= 1.24.
    dtype_map = {
        ms_service_pb2.MS_BOOL: np.bool_,
        ms_service_pb2.MS_INT8: np.int8,
        ms_service_pb2.MS_UINT8: np.uint8,
        ms_service_pb2.MS_INT16: np.int16,
        ms_service_pb2.MS_UINT16: np.uint16,
        ms_service_pb2.MS_INT32: np.int32,
        ms_service_pb2.MS_UINT32: np.uint32,
        ms_service_pb2.MS_INT64: np.int64,
        ms_service_pb2.MS_UINT64: np.uint64,
        ms_service_pb2.MS_FLOAT16: np.float16,
        ms_service_pb2.MS_FLOAT32: np.float32,
        ms_service_pb2.MS_FLOAT64: np.float64,
    }
    np_dtype = dtype_map.get(tensor.dtype)
    if np_dtype is None:
        raise RuntimeError("Unknown data type " + str(tensor.dtype))
    return np.frombuffer(tensor.data, np_dtype).reshape(tensor.shape.dims)
-
-
- def _check_str(arg_name, str_val):
- """Check whether the input parameters are reasonable str input"""
- if not isinstance(str_val, str):
- raise RuntimeError(f"Parameter '{arg_name}' should be str, but actually {type(str_val)}")
- if not str_val:
- raise RuntimeError(f"Parameter '{arg_name}' should not be empty str")
-
-
- def _check_int(arg_name, int_val, mininum=None, maximum=None):
- """Check whether the input parameters are reasonable int input"""
- if not isinstance(int_val, int):
- raise RuntimeError(f"Parameter '{arg_name}' should be int, but actually {type(int_val)}")
- if mininum is not None and int_val < mininum:
- if maximum is not None:
- raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
- raise RuntimeError(f"Parameter '{arg_name}' should be >= {mininum}")
- if maximum is not None and int_val > maximum:
- if mininum is not None:
- raise RuntimeError(f"Parameter '{arg_name}' should be in range [{mininum},{maximum}]")
- raise RuntimeError(f"Parameter '{arg_name}' should be <= {maximum}")
-
-
class Client:
    """
    The Client encapsulates the serving gRPC API, which can be used to create requests,
    access serving, and parse results.

    The client owns a gRPC channel. Call `close()` (or use the client in a `with`
    statement) to release it when the client is no longer needed.

    Args:
        ip (str): Serving ip.
        port (int): Serving port, in range [1, 65535].
        servable_name (str): The name of servable supplied by Serving.
        method_name (str): The name of method supplied by servable.
        version_number (int): The version number of servable, default 0,
            which means the maximum version number in all running versions.
    Raises:
        RuntimeError: The type or value of the parameters is invalid, or other errors happened.

    Examples:
        >>> from mindspore_serving.client import Client
        >>> import numpy as np
        >>> client = Client("localhost", 5500, "add", "add_cast")
        >>> instances = []
        >>> x1 = np.ones((2, 2), np.int32)
        >>> x2 = np.ones((2, 2), np.int32)
        >>> instances.append({"x1": x1, "x2": x2})
        >>> result = client.infer(instances)
        >>> print(result)
        >>> client.close()
    """

    def __init__(self, ip, port, servable_name, method_name, version_number=0):
        _check_str("ip", ip)
        _check_int("port", port, 1, 65535)
        _check_str("servable_name", servable_name)
        _check_str("method_name", method_name)
        _check_int("version_number", version_number, 0)

        self.ip = ip
        self.port = port
        self.servable_name = servable_name
        self.method_name = method_name
        self.version_number = version_number

        channel_str = str(ip) + ":" + str(port)
        msg_bytes_size = 512 * 1024 * 1024  # 512MB
        # Keep a reference to the channel so it can be closed explicitly.
        self._channel = grpc.insecure_channel(channel_str,
                                              options=[
                                                  ('grpc.max_send_message_length', msg_bytes_size),
                                                  ('grpc.max_receive_message_length', msg_bytes_size),
                                              ])
        self.stub = ms_service_pb2_grpc.MSServiceStub(self._channel)

    def close(self):
        """Close the underlying gRPC channel. Calling it more than once is harmless.

        After closing, `infer` can no longer reach the server.
        """
        channel = self._channel
        self._channel = None
        if channel is not None:
            channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def infer(self, instances):
        """
        Used to create requests, access serving, and parse results.

        Args:
            instances (map, tuple of map, list of map): Instance or sequence of instances, every
                instance item is the inputs map. The map key is the input name, and the value is
                the input value.

        Returns:
            On success, a list with one output map per instance; an instance the server
            reported as failed is replaced by {"error": message}. A single {"error": message}
            map is returned when the reply carries exactly one error message, or when the
            gRPC call itself fails.

        Raises:
            RuntimeError: The type or value of the parameters is invalid, or other errors happened.
        """
        if not isinstance(instances, (tuple, list)):
            instances = (instances,)
        request = self._create_request()
        for item in instances:
            if not isinstance(item, dict):
                raise RuntimeError("instance should be a map")
            request.instances.append(self._create_instance(**item))

        try:
            result = self.stub.Predict(request)
        except grpc.RpcError as e:
            status_code = e.code()
            # Report through logging (stderr by default) instead of printing to stdout.
            logging.getLogger(__name__).error("Serving gRPC call failed, status %s %s: %s",
                                              status_code.name, status_code.value, e.details())
            return {"error": "Grpc Error, " + str(status_code.value)}
        return self._paser_result(result)

    def _create_request(self):
        """Used to create request spec."""
        request = ms_service_pb2.PredictRequest()
        request.servable_spec.name = self.servable_name
        request.servable_spec.method_name = self.method_name
        request.servable_spec.version_number = self.version_number
        return request

    def _create_instance(self, **kwargs):
        """Used to create gRPC instance from an input-name -> input-value map."""
        instance = ms_service_pb2.Instance()
        for k, w in kwargs.items():
            tensor = instance.items[k]
            # np.bool_ is not a subclass of np.number, so numpy bool scalars are listed explicitly.
            if isinstance(w, (np.ndarray, np.number, np.bool_)):
                _create_tensor(w, tensor)
            elif isinstance(w, str):
                _create_str_tensor(w, tensor)
            elif isinstance(w, (bool, int, float)):
                _create_scalar_tensor(w, tensor)
            elif isinstance(w, bytes):
                _create_bytes_tensor(w, tensor)
            else:
                raise RuntimeError("Not support value type " + str(type(w)))
        return instance

    def _paser_result(self, result):
        """Used to parse result.

        A reply with exactly one error message is treated as a request-level error.
        Otherwise error messages, if present, must pair one-to-one with instances.
        """
        error_msg_len = len(result.error_msg)
        if error_msg_len == 1:
            return {"error": bytes.decode(result.error_msg[0].error_msg)}
        instance_len = len(result.instances)
        if error_msg_len not in (0, instance_len):
            raise RuntimeError(f"error msg result size {error_msg_len} not be 0, 1 or "
                               f"length of instances {instance_len}")
        ret_val = []
        for i, instance in enumerate(result.instances):
            if error_msg_len == 0 or result.error_msg[i].error_code == 0:
                ret_val.append({k: _create_numpy_from_tensor(w) for k, w in instance.items.items()})
            else:
                ret_val.append({"error": bytes.decode(result.error_msg[i].error_msg)})
        return ret_val
|