# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""gRPC client helpers for sending predict requests to a serving instance."""

import grpc
import numpy as np
import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2
import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc


def _create_tensor(data, tensor=None):
    """Fill a proto ``Tensor`` from a numpy array or numpy scalar.

    :param data: object exposing ``shape``, ``dtype`` and ``tobytes()``
        (numpy array or numpy scalar).
    :param tensor: proto tensor to fill in place; a new one is created
        when None.
    :return: the filled proto tensor.
    :raises RuntimeError: if the dtype of ``data`` has no proto equivalent.
    """
    if tensor is None:
        tensor = ms_service_pb2.Tensor()

    # np.bool_ is used instead of the ``np.bool`` alias, which no longer
    # exists in recent NumPy releases.
    dtype_map = {
        np.dtype(np.bool_): ms_service_pb2.MS_BOOL,
        np.dtype(np.int8): ms_service_pb2.MS_INT8,
        np.dtype(np.uint8): ms_service_pb2.MS_UINT8,
        np.dtype(np.int16): ms_service_pb2.MS_INT16,
        np.dtype(np.uint16): ms_service_pb2.MS_UINT16,
        np.dtype(np.int32): ms_service_pb2.MS_INT32,
        np.dtype(np.uint32): ms_service_pb2.MS_UINT32,
        np.dtype(np.int64): ms_service_pb2.MS_INT64,
        np.dtype(np.uint64): ms_service_pb2.MS_UINT64,
        np.dtype(np.float16): ms_service_pb2.MS_FLOAT16,
        np.dtype(np.float32): ms_service_pb2.MS_FLOAT32,
        np.dtype(np.float64): ms_service_pb2.MS_FLOAT64,
    }
    # Decide on the lookup result rather than on tensor.dtype, so a tensor
    # passed in with a dtype already set cannot mask an unsupported input.
    ms_dtype = dtype_map.get(data.dtype)
    if ms_dtype is None:
        raise RuntimeError("Unknown data type", data.dtype)

    tensor.shape.dims.extend(data.shape)
    tensor.dtype = ms_dtype
    tensor.data = data.tobytes()
    return tensor


def _create_scalar_tensor(vals, tensor=None):
    """Fill a proto ``Tensor`` from a Python scalar or a tuple/list of them.

    A single value is wrapped in a one-element tuple before being converted
    with ``np.array``, so it is sent as a one-element 1-D array.
    """
    if not isinstance(vals, (tuple, list)):
        vals = (vals,)
    return _create_tensor(np.array(vals), tensor)


def _create_bytes_tensor(bytes_vals, tensor=None):
    """Fill a proto ``Tensor`` of type ``MS_BYTES``.

    :param bytes_vals: one value or a tuple/list of values to store in
        ``bytes_val``.
    :param tensor: proto tensor to fill in place; a new one is created
        when None.
    :return: the filled proto tensor.
    """
    if tensor is None:
        tensor = ms_service_pb2.Tensor()
    if not isinstance(bytes_vals, (tuple, list)):
        bytes_vals = (bytes_vals,)

    tensor.shape.dims.extend([len(bytes_vals)])
    tensor.dtype = ms_service_pb2.MS_BYTES
    tensor.bytes_val.extend(bytes_vals)
    return tensor


def _create_str_tensor(str_vals, tensor=None):
    """Fill a proto ``Tensor`` of type ``MS_STRING``.

    Each string is stored UTF-8 encoded in ``bytes_val``.

    :param str_vals: one string or a tuple/list of strings.
    :param tensor: proto tensor to fill in place; a new one is created
        when None.
    :return: the filled proto tensor.
    """
    if tensor is None:
        tensor = ms_service_pb2.Tensor()
    if not isinstance(str_vals, (tuple, list)):
        str_vals = (str_vals,)

    tensor.shape.dims.extend([len(str_vals)])
    tensor.dtype = ms_service_pb2.MS_STRING
    tensor.bytes_val.extend(bytes(item, encoding="utf8") for item in str_vals)
    return tensor


def _create_numpy_from_tensor(tensor):
    """Convert a proto ``Tensor`` back into a Python/numpy value.

    :param tensor: proto tensor received from the server.
    :return: for ``MS_STRING``/``MS_BYTES`` tensors, a list of str/bytes
        (or the single item when there is exactly one); otherwise a numpy
        array, or a numpy scalar when the shape is empty or ``[1]``.
    :raises KeyError: if the tensor dtype is not a supported numeric type.
    """
    dtype_map = {
        ms_service_pb2.MS_BOOL: np.bool_,
        ms_service_pb2.MS_INT8: np.int8,
        ms_service_pb2.MS_UINT8: np.uint8,
        # Previously mapped to the proto enum value instead of a numpy type.
        ms_service_pb2.MS_INT16: np.int16,
        ms_service_pb2.MS_UINT16: np.uint16,
        ms_service_pb2.MS_INT32: np.int32,
        ms_service_pb2.MS_UINT32: np.uint32,
        ms_service_pb2.MS_INT64: np.int64,
        ms_service_pb2.MS_UINT64: np.uint64,
        ms_service_pb2.MS_FLOAT16: np.float16,
        ms_service_pb2.MS_FLOAT32: np.float32,
        ms_service_pb2.MS_FLOAT64: np.float64,
    }

    if tensor.dtype in (ms_service_pb2.MS_STRING, ms_service_pb2.MS_BYTES):
        if tensor.dtype == ms_service_pb2.MS_STRING:
            result = [bytes.decode(item) for item in tensor.bytes_val]
        else:
            result = list(tensor.bytes_val)
        # A single item is unwrapped so scalar inputs round-trip as scalars.
        if len(result) == 1:
            return result[0]
        return result

    dims = tensor.shape.dims
    result = np.frombuffer(tensor.data, dtype_map[tensor.dtype]).reshape(dims)
    # Empty shape or shape [1] is returned as a scalar rather than an array.
    if not dims or (len(dims) == 1 and dims[0] == 1):
        result = result.reshape((1,))[0]
    return result


class Client:
    """Client that sends predict requests for one servable method over gRPC."""

    def __init__(self, ip, port, servable_name, method_name, version_number=None):
        """
        Create Client connect to serving
        :param ip: serving ip
        :param port: serving port
        :param servable_name: the name of servable supplied by serving
        :param method_name: method supplied by servable
        :param version_number: the version number of servable, default None.
            None meaning the maximum version number in all running versions.
        """
        self.ip = ip
        self.port = port
        self.servable_name = servable_name
        self.method_name = method_name
        self.version_number = version_number

        channel_str = str(ip) + ":" + str(port)
        channel = grpc.insecure_channel(channel_str)
        self.stub = ms_service_pb2_grpc.MSServiceStub(channel)

    def _create_request(self):
        """Build a ``PredictRequest`` with the servable spec of this client.

        ``version_number`` is only set on the request when it is not None.
        """
        request = ms_service_pb2.PredictRequest()
        request.servable_spec.name = self.servable_name
        request.servable_spec.method_name = self.method_name
        if self.version_number is not None:
            request.servable_spec.version_number = self.version_number
        return request

    def _create_instance(self, **kwargs):
        """Build a proto ``Instance`` from named input values.

        :param kwargs: input name -> value; supported values are numpy
            arrays/scalars, str, bool, int, float and bytes.
        :return: the filled proto instance.
        :raises RuntimeError: if a value has an unsupported type.
        """
        instance = ms_service_pb2.Instance()
        for name, value in kwargs.items():
            tensor = instance.items[name]
            # np.bool_ is listed explicitly: it is neither an np.number nor
            # a Python bool, so it would otherwise be rejected.
            if isinstance(value, (np.ndarray, np.number, np.bool_)):
                _create_tensor(value, tensor)
            elif isinstance(value, str):
                _create_str_tensor(value, tensor)
            elif isinstance(value, (bool, int, float)):
                _create_scalar_tensor(value, tensor)
            elif isinstance(value, bytes):
                _create_bytes_tensor(value, tensor)
            else:
                raise RuntimeError("Not support value type " + str(type(value)))
        return instance

    def _paser_result(self, result):
        """Convert a predict reply into Python values.

        :param result: reply object with ``error_msg`` and ``instances``.
        :return: ``{"error": msg}`` when the reply carries exactly one error
            message; otherwise a list with one entry per instance, each
            either a dict of output name -> value or ``{"error": msg}``.
        :raises RuntimeError: if the number of error messages is not 0, 1 or
            the number of instances.
        """
        error_msg_len = len(result.error_msg)
        if error_msg_len == 1:
            return {"error": bytes.decode(result.error_msg[0].error_msg)}

        instance_len = len(result.instances)
        if error_msg_len not in (0, instance_len):
            # Both counts are converted with str(); concatenating the raw int
            # would raise TypeError instead of this RuntimeError.
            raise RuntimeError(
                "error msg result size " + str(error_msg_len)
                + " not be 0,1 or length of instances " + str(instance_len))

        ret_val = []
        for i, instance in enumerate(result.instances):
            if error_msg_len == 0 or result.error_msg[i].error_code == 0:
                ret_val.append({
                    name: _create_numpy_from_tensor(tensor)
                    for name, tensor in instance.items.items()
                })
            else:
                ret_val.append(
                    {"error": bytes.decode(result.error_msg[i].error_msg)})
        return ret_val

    def infer(self, instances):
        """Send one predict request containing the given instances.

        :param instances: a dict of input name -> value, or a tuple/list of
            such dicts.
        :return: the parsed reply (see ``_paser_result``), or
            ``{"error": status_code.value}`` when the RPC fails.
        :raises RuntimeError: if an instance is not a dict.
        """
        if not isinstance(instances, (tuple, list)):
            instances = (instances,)

        request = self._create_request()
        for item in instances:
            if not isinstance(item, dict):
                raise RuntimeError("instance should be a map")
            request.instances.append(self._create_instance(**item))

        # Only the RPC itself is expected to raise grpc.RpcError.
        try:
            result = self.stub.Predict(request)
        except grpc.RpcError as e:
            print(e.details())
            status_code = e.code()
            print(status_code.name)
            print(status_code.value)
            return {"error": status_code.value}
        return self._paser_result(result)