| @@ -16,39 +16,74 @@ from ..ops.special import Const | |||
| from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
| def dtype_promotion(raw_inputs): | |||
| def add_dtype(i): | |||
| if type(i) == int: | |||
| return np.array(i, dtype=np.int32) | |||
| if type(i) == float: | |||
| return np.array(i, dtype=np.float32) | |||
| if type(i) == bool: | |||
| return np.array(i, dtype=np.bool_) | |||
| return None | |||
| scalar_inputs = [ | |||
| add_dtype(i) for i in raw_inputs if not hasattr(i, "dtype") and add_dtype(i) | |||
| ] | |||
| inputs = [i for i in raw_inputs if hasattr(i, "dtype")] | |||
| assert len(scalar_inputs + inputs) > 0 | |||
| dtype = None | |||
| if len(inputs) > 0: | |||
| dtype = np.result_type(*inputs) | |||
| dtype_all = np.result_type(*(inputs + scalar_inputs)) | |||
| assert ( | |||
| dtype != np.float64 and dtype != np.int64 | |||
| ), "unsupport dtype {} by dtype_promotion, please use explict type convert".format( | |||
| dtype | |||
| ) | |||
| if dtype_all == np.bool_: | |||
| for i in raw_inputs: | |||
| if not hasattr(i, "dtype") or i.dtype != np.bool_: | |||
| raise TypeError( | |||
| "bool dtype can not be operated with an element without bool dtype" | |||
| ) | |||
| if dtype_all == np.float64: | |||
| dtype_all = np.float32 | |||
| return dtype_all | |||
| def dtype_promotion(inputs): | |||
| """ | |||
| Returns the dtype that would result from performing an arithmetic | |||
| operation on the provided input tensors and scalars. | |||
| """ | |||
| # map numpy.dtype.kind to priority | |||
| category_priority = { | |||
| "f": 3, # floating-point | |||
| "i": 2, # signed integer | |||
| "u": 2, # unsigned integer | |||
| "b": 1, # boolean | |||
| } | |||
| def scalar2dtype(x): | |||
| """ | |||
| For scalar `x`, returns its corresponding type. A floating point scalar | |||
| has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'. | |||
| A boolean scalar has dtype 'bool'. | |||
| """ | |||
| if isinstance(x, bool): | |||
| return np.bool_ | |||
| if isinstance(x, int): | |||
| return np.int32 | |||
| if isinstance(x, float): | |||
| return np.float32 | |||
| def promote_types(types, cat): | |||
| """ | |||
| Returns the data type with sufficient size to hold all types of | |||
| category `cat` in the list `types`. | |||
| """ | |||
| used_types = [ | |||
| i for i in types if category_priority.get(np.dtype(i).kind, 0) == cat | |||
| ] | |||
| assert len(used_types) > 0 | |||
| res = used_types[0] | |||
| for i in used_types: | |||
| res = np.promote_types(res, i) | |||
| return res | |||
| def max_priority(types): | |||
| """ | |||
| Returns the maximum value of the priority of each type in the list | |||
| `types`. | |||
| """ | |||
| if not types: | |||
| return 0 | |||
| else: | |||
| return max([category_priority.get(np.dtype(i).kind, 0) for i in types]) | |||
| scalars = [] | |||
| tensors = [] | |||
| for data in inputs: | |||
| if hasattr(data, "dtype"): | |||
| tensors.append(data.dtype) | |||
| elif isinstance(data, (float, int, bool)): | |||
| scalars.append(scalar2dtype(data)) | |||
| max_pri_scalars = max_priority(scalars) | |||
| max_pri_tensors = max_priority(tensors) | |||
| assert max_pri_scalars > 0 or max_pri_tensors > 0 | |||
| if max_pri_scalars > max_pri_tensors: | |||
| return promote_types(scalars, max_pri_scalars) | |||
| else: | |||
| return promote_types(tensors, max_pri_tensors) | |||
| def get_device(inputs): | |||