|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import grpc
- import numpy as np
- import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2
- import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc
-
-
def _create_tensor(data, tensor=None):
    """Serialize a numpy array (or numpy scalar) into an ``ms_service_pb2.Tensor``.

    :param data: object exposing ``shape``, ``dtype`` and ``tobytes()`` (``np.ndarray`` or a
        numpy scalar); its shape, data type and raw bytes are copied into the message.
    :param tensor: optional existing Tensor message to fill in place; a new one is created
        when None.
    :return: the filled Tensor message.
    :raises RuntimeError: if ``data.dtype`` has no matching MS data type.
    """
    # ``np.bool`` was only a deprecated alias of the builtin ``bool`` and was removed in
    # NumPy 1.24; ``np.bool_`` is the numpy boolean type and exists in every version.
    dtype_map = {
        np.bool_: ms_service_pb2.MS_BOOL,
        np.int8: ms_service_pb2.MS_INT8,
        np.uint8: ms_service_pb2.MS_UINT8,
        np.int16: ms_service_pb2.MS_INT16,
        np.uint16: ms_service_pb2.MS_UINT16,
        np.int32: ms_service_pb2.MS_INT32,
        np.uint32: ms_service_pb2.MS_UINT32,
        np.int64: ms_service_pb2.MS_INT64,
        np.uint64: ms_service_pb2.MS_UINT64,
        np.float16: ms_service_pb2.MS_FLOAT16,
        np.float32: ms_service_pb2.MS_FLOAT32,
        np.float64: ms_service_pb2.MS_FLOAT64,
    }
    # Compare with ``==`` instead of a dict lookup on ``data.dtype.type`` so that
    # equivalent platform aliases (e.g. ``np.longlong`` vs ``np.int64``) still match.
    ms_dtype = None
    for np_type, ms_type in dtype_map.items():
        if data.dtype == np_type:
            ms_dtype = ms_type
            break
    # Resolve the data type before touching ``tensor`` so a rejected input does not leave
    # a half-filled message behind (callers pass in map entries that already exist).
    if ms_dtype is None or ms_dtype == ms_service_pb2.MS_UNKNOWN:
        raise RuntimeError("Unknown data type", data.dtype)

    if tensor is None:
        tensor = ms_service_pb2.Tensor()
    tensor.shape.dims.extend(data.shape)
    tensor.dtype = ms_dtype
    tensor.data = data.tobytes()
    return tensor
-
-
def _create_scalar_tensor(vals, tensor=None):
    """Pack a Python scalar, or a tuple/list of scalars, into a 1-D Tensor message.

    A lone value is wrapped in a 1-tuple first, so the resulting array always has one
    dimension. The numpy array is built with ``np.array`` default type inference and then
    handed to ``_create_tensor``, which fills ``tensor`` (or a new message when None).
    """
    sequence = vals if isinstance(vals, (tuple, list)) else (vals,)
    as_array = np.array(sequence)
    return _create_tensor(as_array, tensor)
-
-
def _create_bytes_tensor(bytes_vals, tensor=None):
    """Store one bytes value, or a tuple/list of them, in a Tensor of type MS_BYTES.

    :param bytes_vals: a single value or a tuple/list of values appended to ``bytes_val``.
    :param tensor: optional Tensor message to fill in place; a new one is created when None.
    :return: the filled Tensor, whose shape is ``[number of values]``.
    """
    payload = bytes_vals if isinstance(bytes_vals, (tuple, list)) else (bytes_vals,)
    if tensor is None:
        tensor = ms_service_pb2.Tensor()

    tensor.shape.dims.extend([len(payload)])
    tensor.dtype = ms_service_pb2.MS_BYTES
    for chunk in payload:
        tensor.bytes_val.append(chunk)
    return tensor
-
-
def _create_str_tensor(str_vals, tensor=None):
    """Store one string, or a tuple/list of strings, in a Tensor of type MS_STRING.

    Each string is UTF-8 encoded and appended to ``bytes_val``.

    :param str_vals: a single string or a tuple/list of strings.
    :param tensor: optional Tensor message to fill in place; a new one is created when None.
    :return: the filled Tensor, whose shape is ``[number of strings]``.
    """
    texts = str_vals if isinstance(str_vals, (tuple, list)) else (str_vals,)
    if tensor is None:
        tensor = ms_service_pb2.Tensor()

    tensor.shape.dims.extend([len(texts)])
    tensor.dtype = ms_service_pb2.MS_STRING
    for text in texts:
        encoded = bytes(text, encoding="utf8")
        tensor.bytes_val.append(encoded)
    return tensor
-
-
def _create_numpy_from_tensor(tensor):
    """Convert an ``ms_service_pb2.Tensor`` message back into Python / numpy values.

    - MS_STRING tensors become ``str`` values (each ``bytes_val`` entry decoded).
    - MS_BYTES tensors become the raw ``bytes_val`` entries.
      For both, a single entry is returned unwrapped and otherwise a list is returned.
    - Numeric tensors become a numpy array reshaped to ``tensor.shape.dims``. A tensor
      with no dims, or with dims ``[1]``, is returned as a numpy scalar instead.

    :param tensor: the Tensor message to convert.
    :return: str, bytes, list, numpy scalar or numpy array as described above.
    :raises RuntimeError: if the tensor's data type is not supported.
    """
    if tensor.dtype in (ms_service_pb2.MS_STRING, ms_service_pb2.MS_BYTES):
        if tensor.dtype == ms_service_pb2.MS_STRING:
            result = [bytes.decode(item) for item in tensor.bytes_val]
        else:
            result = list(tensor.bytes_val)
        if len(result) == 1:
            return result[0]
        return result

    # MS_INT16 previously mapped to the enum value itself instead of ``np.int16``, and
    # ``np.bool`` no longer exists in NumPy >= 1.24; ``np.bool_`` is the portable name.
    dtype_map = {
        ms_service_pb2.MS_BOOL: np.bool_,
        ms_service_pb2.MS_INT8: np.int8,
        ms_service_pb2.MS_UINT8: np.uint8,
        ms_service_pb2.MS_INT16: np.int16,
        ms_service_pb2.MS_UINT16: np.uint16,
        ms_service_pb2.MS_INT32: np.int32,
        ms_service_pb2.MS_UINT32: np.uint32,
        ms_service_pb2.MS_INT64: np.int64,
        ms_service_pb2.MS_UINT64: np.uint64,
        ms_service_pb2.MS_FLOAT16: np.float16,
        ms_service_pb2.MS_FLOAT32: np.float32,
        ms_service_pb2.MS_FLOAT64: np.float64,
    }
    np_type = dtype_map.get(tensor.dtype)
    if np_type is None:
        raise RuntimeError("Unknown data type " + str(tensor.dtype))

    result = np.frombuffer(tensor.data, np_type).reshape(tensor.shape.dims)
    # Zero-rank or single-element 1-D tensors are unwrapped to a numpy scalar.
    if not tensor.shape.dims or (len(tensor.shape.dims) == 1 and tensor.shape.dims[0] == 1):
        result = result.reshape((1,))[0]
    return result
-
-
class Client:
    """gRPC client that sends Predict requests to one method of a servable."""

    def __init__(self, ip, port, servable_name, method_name, version_number=None):
        '''
        Create a Client connected to serving.
        :param ip: serving ip
        :param port: serving port
        :param servable_name: the name of servable supplied by serving
        :param method_name: method supplied by servable
        :param version_number: the version number of servable, default None.
            None means the maximum version number among all running versions.
        '''
        self.ip = ip
        self.port = port
        self.servable_name = servable_name
        self.method_name = method_name
        self.version_number = version_number
        channel_str = str(ip) + ":" + str(port)
        # Plaintext (non-TLS) channel; it is kept alive only through the stub.
        channel = grpc.insecure_channel(channel_str)
        self.stub = ms_service_pb2_grpc.MSServiceStub(channel)

    def _create_request(self):
        """Build an empty PredictRequest addressed to this client's servable method."""
        request = ms_service_pb2.PredictRequest()
        request.servable_spec.name = self.servable_name
        request.servable_spec.method_name = self.method_name
        if self.version_number is not None:
            request.servable_spec.version_number = self.version_number
        return request

    def _create_instance(self, **kwargs):
        """Build an Instance message whose ``items`` map each keyword name to a Tensor.

        Supported value types: numpy arrays/scalars, str, bool/int/float, bytes.

        :raises RuntimeError: for any other value type.
        """
        instance = ms_service_pb2.Instance()
        for k, w in kwargs.items():
            tensor = instance.items[k]
            # ``np.bool_`` is not a subclass of ``np.number``, so numpy boolean scalars
            # must be listed explicitly or they would be rejected as unsupported.
            if isinstance(w, (np.ndarray, np.number, np.bool_)):
                _create_tensor(w, tensor)
            elif isinstance(w, str):
                _create_str_tensor(w, tensor)
            elif isinstance(w, (bool, int, float)):
                _create_scalar_tensor(w, tensor)
            elif isinstance(w, bytes):
                _create_bytes_tensor(w, tensor)
            else:
                raise RuntimeError("Not support value type " + str(type(w)))
        return instance

    def _paser_result(self, result):
        """Turn a Predict reply into Python values.

        :return: ``{"error": msg}`` when the reply carries exactly one error message;
            otherwise a list with one entry per instance, either a dict mapping output
            name to value or ``{"error": msg}`` for instances with a non-zero error_code.
        :raises RuntimeError: if the number of error messages is neither 0, 1 nor the
            number of instances.
        """
        error_msg_len = len(result.error_msg)
        # NOTE(review): a single error message is treated as a request-level failure even
        # when there is exactly one instance — confirm this matches the server protocol.
        if error_msg_len == 1:
            return {"error": bytes.decode(result.error_msg[0].error_msg)}
        instance_len = len(result.instances)
        if error_msg_len not in (0, instance_len):
            # str() is required: concatenating the int directly raised TypeError.
            raise RuntimeError(
                "error msg result size " + str(error_msg_len) + " not be 0,1 or length of instances " + str(
                    instance_len))
        ret_val = []
        for i, instance in enumerate(result.instances):
            if error_msg_len == 0 or result.error_msg[i].error_code == 0:
                ret_val.append({k: _create_numpy_from_tensor(w) for k, w in instance.items.items()})
            else:
                ret_val.append({"error": bytes.decode(result.error_msg[i].error_msg)})
        return ret_val

    def infer(self, instances):
        """Send one or more instances to the servable method and return the parsed reply.

        :param instances: a dict mapping input name to value, or a tuple/list of such dicts.
        :return: see ``_paser_result``; on a gRPC failure, ``{"error": status_code.value}``.
        :raises RuntimeError: if an instance is not a dict or holds an unsupported value.
        """
        if not isinstance(instances, (tuple, list)):
            instances = (instances,)
        request = self._create_request()
        for item in instances:
            if not isinstance(item, dict):
                raise RuntimeError("instance should be a map")
            request.instances.append(self._create_instance(**item))

        # Only the RPC itself can raise grpc.RpcError; keep the try body to that call.
        try:
            result = self.stub.Predict(request)
        except grpc.RpcError as e:
            print(e.details())
            status_code = e.code()
            print(status_code.name)
            print(status_code.value)
            return {"error": status_code.value}
        return self._paser_result(result)
|