GitOrigin-RevId: baef3d348c
tags/v1.2.0
| @@ -8,7 +8,7 @@ | |||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import numpy as np | import numpy as np | ||||
| from .._imperative_rt.core2 import Tensor | |||||
| # from .._imperative_rt.core2 import Tensor | |||||
| from ..tensor.core import OpBase, TensorBase, apply | from ..tensor.core import OpBase, TensorBase, apply | ||||
| @@ -19,5 +19,10 @@ class Const: | |||||
| self.device = device | self.device = device | ||||
def __call__(self, *reference):
    """Build the constant and return it as a one-element tuple.

    The device handed to ``Tensor`` is ``self.device`` when it is set.
    When it is ``None`` and at least one reference object was passed, the
    device of the first reference is used instead. When it is ``None`` and
    no reference was passed, ``None`` is forwarded unchanged.

    :param reference: optional objects exposing a ``device`` attribute;
        only the first one is consulted, and only as a device fallback.
    :return: ``(Tensor(self.value, self.dtype, device, True),)``.
    """
    # Imported lazily inside the call, as in the original code.
    from ...tensor import Tensor

    device = self.device
    # Guard on ``reference`` so that a call with no arguments does not
    # raise IndexError.
    if device is None and reference:
        device = reference[0].device

    # Pass the resolved device; the previous code computed it and then
    # passed ``self.device``, discarding the fallback.
    return (Tensor(self.value, self.dtype, device, True),)
| @@ -13,6 +13,12 @@ import numpy as np | |||||
| # normal dtype related | # normal dtype related | ||||
| from .._imperative_rt import bfloat16, intb1, intb2, intb4 | from .._imperative_rt import bfloat16, intb1, intb2, intb4 | ||||
| from .._imperative_rt.common import ( | |||||
| get_scale, | |||||
| get_zero_point, | |||||
| is_dtype_equal, | |||||
| is_quantize, | |||||
| ) | |||||
| def is_lowbit(dtype): | def is_lowbit(dtype): | ||||
| @@ -42,41 +48,6 @@ _metadata_dict = { | |||||
| } | } | ||||
def is_quantize(dtype):
    """Tell whether ``dtype`` carries MegBrain quantization metadata.

    Returns ``True`` only if ``dtype`` has a ``metadata`` attribute that is
    not ``None`` and contains the key ``"mgb_dtype"``.
    """
    meta = getattr(dtype, "metadata", None)
    if meta is None:
        return False
    return "mgb_dtype" in meta
def get_scale(dtype):
    """Return the ``"scale"`` entry of a quantized dtype's ``"mgb_dtype"`` metadata.

    Asserts that ``dtype`` is quantized according to ``is_quantize``.
    """
    assert is_quantize(dtype)
    mgb_meta = dtype.metadata["mgb_dtype"]
    return mgb_meta["scale"]
def get_zero_point(dtype):
    """Return the ``"zero_point"`` entry of an asymmetric quantized dtype.

    Asserts that ``dtype`` is quantized and that its ``"mgb_dtype"`` name is
    either ``"Quantized8Asymm"`` or ``"Quantized4Asymm"``.
    """
    assert is_quantize(dtype)
    mgb_meta = dtype.metadata["mgb_dtype"]
    kind_name = mgb_meta["name"]
    assert kind_name in ("Quantized8Asymm", "Quantized4Asymm")
    return mgb_meta["zero_point"]
def is_equal(dt0, dt1):
    """Compare two dtypes, taking quantization metadata into account.

    * Both quantized: equal when their scales match and their
      ``"zero_point"`` entries (``None`` when the key is absent) match.
    * Neither quantized: plain ``dt0 == dt1``.
    * Exactly one quantized: ``False``.
    """
    quant0 = is_quantize(dt0)
    quant1 = quant0 and is_quantize(dt1)

    if quant0 and quant1:

        def zero_point_of(dt):
            # ``.get`` so that a missing zero point compares as None.
            return dt.metadata["mgb_dtype"].get("zero_point")

        return get_scale(dt0) == get_scale(dt1) and zero_point_of(
            dt0
        ) == zero_point_of(dt1)

    if quant0 or is_quantize(dt1):
        return False
    return dt0 == dt1
| def _check_zero_point(zp: int, dtype_str: str): | def _check_zero_point(zp: int, dtype_str: str): | ||||
| qmin = _metadata_dict[dtype_str].qmin | qmin = _metadata_dict[dtype_str].qmin | ||||
| qmax = _metadata_dict[dtype_str].qmax | qmax = _metadata_dict[dtype_str].qmax | ||||
| @@ -151,9 +151,9 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
| def get_index(i): | def get_index(i): | ||||
| if not isinstance(i, (Tensor)): | if not isinstance(i, (Tensor)): | ||||
| if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | ||||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)() | |||||
| else: | else: | ||||
| (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||||
| (i,) = Const(i, dtype=np.int32, device=inp.device)() | |||||
| return i | return i | ||||
| assert isinstance(i, Tensor) | assert isinstance(i, Tensor) | ||||
| if i.dtype != np.bool_: | if i.dtype != np.bool_: | ||||
| @@ -197,7 +197,7 @@ def try_condtake(tensor, index): | |||||
| ): | ): | ||||
| return [] | return [] | ||||
| if isinstance(index, np.ndarray): | if isinstance(index, np.ndarray): | ||||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | |||||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)() | |||||
| assert isinstance(index, Tensor) | assert isinstance(index, Tensor) | ||||
| if not isinstance(tensor, Tensor): | if not isinstance(tensor, Tensor): | ||||
| raise TypeError("input must be a tensor") | raise TypeError("input must be a tensor") | ||||
| @@ -217,9 +217,7 @@ def getitem(tensor, index): | |||||
| if isinstance(v.shape, v.__class__): | if isinstance(v.shape, v.__class__): | ||||
| break | break | ||||
| if len(v.shape) > 0 and v.shape[0] == 0: | if len(v.shape) > 0 and v.shape[0] == 0: | ||||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||||
| tensor | |||||
| ) | |||||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() | |||||
| return empty_tensor | return empty_tensor | ||||
| if use_subtensor: | if use_subtensor: | ||||
| op = builtin.Subtensor(items=items) | op = builtin.Subtensor(items=items) | ||||
| @@ -240,8 +238,7 @@ def setitem(tensor, index, value): | |||||
| return tensor | return tensor | ||||
| tensor = tensor.reshape(-1) | tensor = tensor.reshape(-1) | ||||
| if not isinstance(value, Tensor): | if not isinstance(value, Tensor): | ||||
| op = Const(value, dtype=tensor.dtype, device=tensor.device) | |||||
| (value,) = op(tensor) | |||||
| (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() | |||||
| tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | ||||
| for v in tensors: | for v in tensors: | ||||
| if len(v.shape) > 0 and v.shape[0] == 0: | if len(v.shape) > 0 and v.shape[0] == 0: | ||||
| @@ -11,10 +11,10 @@ from typing import Iterable, Union | |||||
| import numpy as np | import numpy as np | ||||
| from .._imperative_rt.core2 import Tensor, apply | |||||
| from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from ..ops.special import Const | from ..ops.special import Const | ||||
| from .dtype import is_equal, is_quantize | |||||
| from .dtype import is_dtype_equal, is_quantize | |||||
| from .megbrain_graph import VarNode | from .megbrain_graph import VarNode | ||||
| _enable_convert_inputs = True | _enable_convert_inputs = True | ||||
| @@ -37,94 +37,12 @@ def set_convert_inputs(flag): | |||||
| return backup | return backup | ||||
def dtype_promotion(inputs):
    """
    Returns the dtype that would result from performing an arithmetic
    operation on the provided input tensors and scalars.

    Objects with a ``dtype`` attribute count as tensors; Python ``bool``,
    ``int`` and ``float`` values count as scalars (mapped to ``bool``,
    ``int32`` and ``float32``). Other inputs are ignored. Scalars only
    decide the result when their category (float > integer > bool) is
    strictly higher than that of every tensor.
    """
    # numpy.dtype.kind -> category priority; unknown kinds rank as 0
    kind_rank = {"f": 3, "i": 2, "u": 2, "b": 1}

    def rank_of(tp):
        return kind_rank.get(np.dtype(tp).kind, 0)

    def highest_rank(types):
        # 0 for an empty list, otherwise the best rank found
        return max([rank_of(tp) for tp in types], default=0)

    def scalar_dtype(x):
        # bool must be tested first: bool is a subclass of int
        if isinstance(x, bool):
            return np.bool_
        if isinstance(x, int):
            return np.int32
        if isinstance(x, float):
            return np.float32

    def promote_within(types, cat):
        # smallest dtype able to hold every member of category `cat`
        members = [tp for tp in types if rank_of(tp) == cat]
        assert len(members) > 0
        promoted = members[0]
        for tp in members:
            promoted = np.promote_types(promoted, tp)
        return promoted

    scalar_types = []
    tensor_types = []
    for item in inputs:
        if hasattr(item, "dtype"):
            tensor_types.append(item.dtype)
        elif isinstance(item, (float, int, bool)):
            scalar_types.append(scalar_dtype(item))

    scalar_rank = highest_rank(scalar_types)
    tensor_rank = highest_rank(tensor_types)
    assert scalar_rank > 0 or tensor_rank > 0

    if scalar_rank > tensor_rank:
        return promote_within(scalar_types, scalar_rank)
    return promote_within(tensor_types, tensor_rank)
def get_device(inputs):
    """Return the single device shared by all Tensor/VarNode inputs.

    Inputs that are neither ``Tensor`` nor ``VarNode`` are ignored.

    :raises ValueError: if two such inputs report different devices.
    :raises AssertionError: if no Tensor/VarNode input is present.
    """
    found = None
    for item in inputs:
        if not isinstance(item, (Tensor, VarNode)):
            continue
        if found is None:
            found = item.device
            continue
        if found != item.device:
            raise ValueError("ambiguous device: {} vs {}".format(found, item.device))
    assert found is not None
    return found
| def concatenate(inputs, axis=0, *, device=None): | def concatenate(inputs, axis=0, *, device=None): | ||||
| dtype = dtype_promotion(inputs) | dtype = dtype_promotion(inputs) | ||||
| device = get_device(inputs) | device = get_device(inputs) | ||||
| def convert(x): | def convert(x): | ||||
| return convert_single_value(x, inputs, dtype=dtype) | |||||
| return convert_single_value(x, dtype=dtype, device=device) | |||||
| inputs = tuple(map(convert, inputs)) | inputs = tuple(map(convert, inputs)) | ||||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | ||||
| @@ -133,7 +51,7 @@ def concatenate(inputs, axis=0, *, device=None): | |||||
| def astype(x, dtype): | def astype(x, dtype): | ||||
| dtype = np.dtype(dtype) | dtype = np.dtype(dtype) | ||||
| if not is_equal(x.dtype, dtype): | |||||
| if not is_dtype_equal(x.dtype, dtype): | |||||
| isscalar = x.isscalar() | isscalar = x.isscalar() | ||||
| (x,) = apply(builtin.TypeCvt(dtype=dtype), x) | (x,) = apply(builtin.TypeCvt(dtype=dtype), x) | ||||
| if isscalar: | if isscalar: | ||||
| @@ -141,13 +59,12 @@ def astype(x, dtype): | |||||
| return x | return x | ||||
| def convert_single_value(v, inputs, *, dtype=None, device=None): | |||||
| tensors = [i for i in inputs if isinstance(i, (Tensor, VarNode))] | |||||
| assert len(tensors) > 0 | |||||
| def convert_single_value(v, *, dtype=None, device=None): | |||||
| if isinstance(v, (Tensor, VarNode)): | if isinstance(v, (Tensor, VarNode)): | ||||
| v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) | |||||
| if not is_quantize(v.dtype): | |||||
| v = astype(v, dtype) | |||||
| else: | else: | ||||
| (v,) = Const(v, dtype=dtype, device=device)(*tensors) | |||||
| (v,) = Const(v, dtype=dtype, device=device)() | |||||
| return v | return v | ||||
| @@ -161,7 +78,7 @@ def convert_inputs(*args: Tensor): | |||||
| def convert(value): | def convert(value): | ||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| return convert_single_value(value, args, dtype=dtype, device=device) | |||||
| return convert_single_value(value, dtype=dtype, device=device) | |||||
| return tuple(map(convert, args)) | return tuple(map(convert, args)) | ||||
| @@ -703,7 +703,7 @@ def topk( | |||||
| op = builtin.TopK(mode=mode) | op = builtin.TopK(mode=mode) | ||||
| if not isinstance(k, Tensor): | if not isinstance(k, Tensor): | ||||
| (k,) = Const(k, dtype="int32", device=inp.device)(inp) | |||||
| (k,) = Const(k, dtype="int32", device=inp.device)() | |||||
| if len(inp.shape) == 1: | if len(inp.shape) == 1: | ||||
| inp = inp.reshape(1, -1) | inp = inp.reshape(1, -1) | ||||
| @@ -658,7 +658,7 @@ def batch_norm( | |||||
| def make_full_if_none(x, value): | def make_full_if_none(x, value): | ||||
| if x is None: | if x is None: | ||||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)() | |||||
| shape = utils.astensor1d( | shape = utils.astensor1d( | ||||
| (1, C, 1, 1), inp, dtype="int32", device=inp.device | (1, C, 1, 1), inp, dtype="int32", device=inp.device | ||||
| ) | ) | ||||
| @@ -1567,7 +1567,7 @@ def indexing_one_hot( | |||||
| """ | """ | ||||
| assert isinstance(src, Tensor), "src must be of Tensor type" | assert isinstance(src, Tensor), "src must be of Tensor type" | ||||
| op = builtin.IndexingOneHot(axis=axis) | op = builtin.IndexingOneHot(axis=axis) | ||||
| index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) | |||||
| index = utils.convert_single_value(index, dtype="int32", device=src.device) | |||||
| (result,) = apply(op, src, index) | (result,) = apply(op, src, index) | ||||
| if not keepdims: | if not keepdims: | ||||
| result = squeeze(result, axis) | result = squeeze(result, axis) | ||||
| @@ -107,9 +107,7 @@ def full(shape, value, dtype="float32", device=None): | |||||
| shape = (shape,) | shape = (shape,) | ||||
| if device is None: | if device is None: | ||||
| device = get_default_device() | device = get_default_device() | ||||
| (x,) = Const(value, dtype=dtype, device=device)( | |||||
| Tensor(value, dtype=dtype, device=device) | |||||
| ) | |||||
| (x,) = Const(value, dtype=dtype, device=device)() | |||||
| if len(shape) == 0: # scalar | if len(shape) == 0: # scalar | ||||
| return x | return x | ||||
| return broadcast_to(x, shape) | return broadcast_to(x, shape) | ||||
| @@ -265,7 +263,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||||
| device = as_device(device) | device = as_device(device) | ||||
| def convert(x): | def convert(x): | ||||
| return convert_single_value(x, inps, dtype=dtype) | |||||
| return convert_single_value(x, dtype=dtype, device=device) | |||||
| inps = tuple(map(convert, inps)) | inps = tuple(map(convert, inps)) | ||||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | ||||
| @@ -37,8 +37,10 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
| else: | else: | ||||
| cn = CompNode(device) | cn = CompNode(device) | ||||
| else: | else: | ||||
| assert isinstance(device, CompNode) | |||||
| cn = device | |||||
| if isinstance(device, CompNode): | |||||
| cn = device | |||||
| else: | |||||
| cn = device._cn | |||||
| # import pdb; pdb.set_trace() | # import pdb; pdb.set_trace() | ||||
| if isinstance(data, _Tensor): | if isinstance(data, _Tensor): | ||||
| @@ -179,4 +179,5 @@ void init_common(py::module m) { | |||||
| init_npy_num_bfloat16(m); | init_npy_num_bfloat16(m); | ||||
| init_npy_num_intbx(m); | init_npy_num_intbx(m); | ||||
| init_dtypes(m); | |||||
| } | } | ||||
| @@ -158,7 +158,7 @@ void PyExceptionForward::throw_() { | |||||
| /* ============== namespace npy ============== */ | /* ============== namespace npy ============== */ | ||||
| namespace { | |||||
| namespace npy { | |||||
| int to_mgb_supported_dtype_raw(int dtype) { | int to_mgb_supported_dtype_raw(int dtype) { | ||||
| if (dtype == NPY_INT64) | if (dtype == NPY_INT64) | ||||
| @@ -199,12 +199,6 @@ int dtype_mgb2np_raw(DType dtype) { | |||||
| "can not convert dtype %s to numpy dtype", dtype.name())); | "can not convert dtype %s to numpy dtype", dtype.name())); | ||||
| } | } | ||||
| struct PyArrayDescrDeleter { | |||||
| void operator()(PyArray_Descr* obj) { | |||||
| Py_XDECREF(obj); | |||||
| } | |||||
| }; | |||||
| //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | ||||
| //! reference to the descriptor. | //! reference to the descriptor. | ||||
| std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | ||||
| @@ -585,9 +579,7 @@ void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) { | |||||
| HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr)); | HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr)); | ||||
| } | } | ||||
| } // anonymous namespace | |||||
| PyObject* npy::ndarray_from_tensor( | |||||
| PyObject* ndarray_from_tensor( | |||||
| const HostTensorND &val, ShareType share_type) { | const HostTensorND &val, ShareType share_type) { | ||||
| if (!val.layout().is_contiguous() && !val.shape().is_empty()) { | if (!val.layout().is_contiguous() && !val.shape().is_empty()) { | ||||
| mgb_assert(share_type != ShareType::MUST_SHARE); | mgb_assert(share_type != ShareType::MUST_SHARE); | ||||
| @@ -634,7 +626,7 @@ PyObject* npy::ndarray_from_tensor( | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||||
| HostTensorND np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||||
| auto ret_full = np2tensor_try_borrow(obj, meth, dtype); | auto ret_full = np2tensor_try_borrow(obj, meth, dtype); | ||||
| if (meth.must_borrow_) { | if (meth.must_borrow_) { | ||||
| mgb_assert(ret_full.second, | mgb_assert(ret_full.second, | ||||
| @@ -645,7 +637,7 @@ HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||||
| return ret_full.first; | return ret_full.first; | ||||
| } | } | ||||
| PyObject* npy::dtype_mgb2np(mgb::DType dtype) { | |||||
| PyObject* dtype_mgb2np(mgb::DType dtype) { | |||||
| PYTHON_GIL; | PYTHON_GIL; | ||||
| // According to | // According to | ||||
| // https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType | // https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType | ||||
| @@ -668,7 +660,7 @@ PyObject* npy::dtype_mgb2np(mgb::DType dtype) { | |||||
| return typeobj; | return typeobj; | ||||
| } | } | ||||
| mgb::DType npy::dtype_np2mgb(PyObject *obj) { | |||||
| mgb::DType dtype_np2mgb(PyObject *obj) { | |||||
| mgb_assert(obj && obj != Py_None, | mgb_assert(obj && obj != Py_None, | ||||
| "can not convert null PyObject to numpy dtype"); | "can not convert null PyObject to numpy dtype"); | ||||
| // see | // see | ||||
| @@ -686,7 +678,7 @@ mgb::DType npy::dtype_np2mgb(PyObject *obj) { | |||||
| return result; | return result; | ||||
| } | } | ||||
| PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { | |||||
| PyObject* to_mgb_supported_dtype(PyObject* dtype) { | |||||
| PYTHON_GIL; | PYTHON_GIL; | ||||
| PyArray_Descr* descr; | PyArray_Descr* descr; | ||||
| @@ -702,7 +694,7 @@ PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { | |||||
| return PyArray_TypeObjectFromType(type_num); | return PyArray_TypeObjectFromType(type_num); | ||||
| } | } | ||||
| TensorShape npy::vec2shape(const std::vector<size_t> &vec) { | |||||
| TensorShape vec2shape(const std::vector<size_t> &vec) { | |||||
| TensorShape shape; | TensorShape shape; | ||||
| mgb_assert(vec.size() <= TensorShape::MAX_NDIM, | mgb_assert(vec.size() <= TensorShape::MAX_NDIM, | ||||
| "dim too large: %zd (max %zd)", | "dim too large: %zd (max %zd)", | ||||
| @@ -718,3 +710,5 @@ TensorShape npy::vec2shape(const std::vector<size_t> &vec) { | |||||
| mgb_assert(shape.ndim, "shape should not be empty"); | mgb_assert(shape.ndim, "shape should not be empty"); | ||||
| return shape; | return shape; | ||||
| } | } | ||||
| } // namespace npy | |||||
| @@ -11,7 +11,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megbrain/graph.h" | |||||
| #include "megbrain/common.h" | |||||
| #include "megbrain/utils/persistent_cache.h" | #include "megbrain/utils/persistent_cache.h" | ||||
| #include "megbrain/imperative/op_def.h" | #include "megbrain/imperative/op_def.h" | ||||
| @@ -26,6 +26,8 @@ | |||||
| #include <pybind11/numpy.h> | #include <pybind11/numpy.h> | ||||
| #include <pybind11/functional.h> | #include <pybind11/functional.h> | ||||
| #include "./numpy_dtypes.h" | |||||
| pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr); | pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr); | ||||
| pybind11::module rel_import(pybind11::str name, pybind11::module m, int level); | pybind11::module rel_import(pybind11::str name, pybind11::module m, int level); | ||||
| @@ -182,6 +184,18 @@ namespace npy { | |||||
| //! convert raw vector to tensor shape | //! convert raw vector to tensor shape | ||||
| mgb::TensorShape vec2shape(const std::vector<size_t> &vec); | mgb::TensorShape vec2shape(const std::vector<size_t> &vec); | ||||
| struct PyArrayDescrDeleter { | |||||
| void operator()(PyArray_Descr* obj) { | |||||
| Py_XDECREF(obj); | |||||
| } | |||||
| }; | |||||
| //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | |||||
| //! reference to the descriptor. | |||||
| std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(mgb::DType dtype); | |||||
| mgb::DType dtype_np2mgb_descr(PyArray_Descr* descr); | |||||
| //! convert megbrain dtype to numpy dtype object; return new reference | //! convert megbrain dtype to numpy dtype object; return new reference | ||||
| PyObject* dtype_mgb2np(mgb::DType dtype); | PyObject* dtype_mgb2np(mgb::DType dtype); | ||||
| @@ -0,0 +1,179 @@ | |||||
| /** | |||||
| * \file imperative/python/src/numpy_dtypes.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "./numpy_dtypes.h" | |||||
| #include "./helper.h" | |||||
| #include "./pyext17.h" | |||||
| #include "pybind11/pybind11.h" | |||||
| #include <cstring> | |||||
| namespace py = pybind11; | |||||
| namespace mgb { | |||||
| namespace { | |||||
| inline bool _is_quantize(PyArray_Descr* dtype) { | |||||
| static PyObject* PY_MGB_DTYPE_KEY = PyUnicode_FromString("mgb_dtype"); | |||||
| return dtype->metadata && | |||||
| PyDict_CheckExact(dtype->metadata) && | |||||
| PyDict_Contains(dtype->metadata, PY_MGB_DTYPE_KEY) == 1; | |||||
| } | |||||
| PyObject* _get_mgb_dtype(PyArray_Descr* dtype) { | |||||
| // Return value: New reference. | |||||
| if (!_is_quantize(dtype)) { | |||||
| throw py::type_error("expact quantize dtype"); | |||||
| } | |||||
| PyObject* ob = PyDict_GetItemString(dtype->metadata, "mgb_dtype"); | |||||
| if (!PyDict_CheckExact(ob)) { | |||||
| throw py::type_error("mgb_dtype is not dict"); | |||||
| } | |||||
| Py_INCREF(ob); | |||||
| return ob; | |||||
| } | |||||
| double _get_scale(PyArray_Descr* dtype) { | |||||
| PyObject* ob = _get_mgb_dtype(dtype); | |||||
| PyObject* scale = PyDict_GetItemString(ob, "scale"); | |||||
| if (!scale) { | |||||
| Py_DECREF(ob); | |||||
| throw py::key_error("scale"); | |||||
| } | |||||
| if (!PyFloat_Check(scale)) { | |||||
| Py_DECREF(ob); | |||||
| throw py::type_error("scale is not float"); | |||||
| } | |||||
| double ret = PyFloat_AsDouble(scale); | |||||
| Py_DECREF(ob); | |||||
| return ret; | |||||
| } | |||||
| long _get_zero_point(PyArray_Descr* dtype) { | |||||
| PyObject* ob = _get_mgb_dtype(dtype); | |||||
| PyObject* name = PyDict_GetItemString(ob, "name"); | |||||
| if (!name) { | |||||
| Py_DECREF(ob); | |||||
| throw py::key_error("name"); | |||||
| } | |||||
| const char* s = PyUnicode_AsUTF8(name); | |||||
| if (strcmp(s, "Quantized8Asymm") != 0 && strcmp(s, "Quantized4Asymm") != 0) { | |||||
| Py_DECREF(ob); | |||||
| throw py::value_error(ssprintf("expect name to be \"Quantized8Asymm\" or \"Quantized4Asymm\", got %s", s)); | |||||
| } | |||||
| PyObject* zp = PyDict_GetItemString(ob, "zero_point"); | |||||
| if (!zp) { | |||||
| Py_DECREF(ob); | |||||
| throw py::key_error("zero_point"); | |||||
| } | |||||
| long ret = PyLong_AsLong(zp); | |||||
| Py_DECREF(ob); | |||||
| return ret; | |||||
| } | |||||
| bool _is_dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) { | |||||
| bool q1 = _is_quantize(dt1), | |||||
| q2 = _is_quantize(dt2); | |||||
| if (q1 && q2) { | |||||
| if (_get_scale(dt1) != _get_scale(dt2)) { | |||||
| return false; | |||||
| } | |||||
| PyObject* zp1 = PyDict_GetItemString( | |||||
| PyDict_GetItemString(dt1->metadata, "mgb_dtype"), "zero_point"); | |||||
| PyObject* zp2 = PyDict_GetItemString( | |||||
| PyDict_GetItemString(dt2->metadata, "mgb_dtype"), "zero_point"); | |||||
| if (!zp1 || !zp2) { | |||||
| throw py::key_error("zero_point"); | |||||
| } | |||||
| return PyLong_AsLong(zp1) == PyLong_AsLong(zp2); | |||||
| } | |||||
| if (!q1 && !q2) { | |||||
| return dt1->type_num == dt2->type_num; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| template<auto f> | |||||
| struct _wrap { | |||||
| static constexpr size_t n_args = []() { | |||||
| using F = decltype(f); | |||||
| using T = PyArray_Descr*; | |||||
| static_assert(std::is_pointer<F>::value); | |||||
| if constexpr (std::is_invocable<F, T>::value) { | |||||
| return 1; | |||||
| } else if constexpr (std::is_invocable<F, T, T>::value) { | |||||
| return 2; | |||||
| } else { | |||||
| static_assert(!std::is_same_v<F, F>, "unreachable"); | |||||
| } | |||||
| }(); | |||||
| static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargs) { | |||||
| if (nargs != n_args) { | |||||
| PyErr_Format(PyExc_ValueError, "expected %lu arguments", n_args); | |||||
| return nullptr; | |||||
| } | |||||
| for (size_t i=0; i<nargs; ++i) { | |||||
| if (args[i] == Py_None) { | |||||
| PyErr_SetString(PyExc_ValueError, "can not convert null PyObject to numpy dtype"); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| try { | |||||
| PyArray_Descr *dt1; | |||||
| if(!PyArray_DescrConverter(args[0], &dt1)) { | |||||
| throw ConversionError(ssprintf("can not convert to numpy.dtype from %s", | |||||
| args[0]->ob_type->tp_name)); | |||||
| } | |||||
| if constexpr (n_args == 1) { | |||||
| auto res = (*f)(dt1); | |||||
| Py_DECREF(dt1); | |||||
| return py::cast(res).release().ptr(); | |||||
| } else { | |||||
| PyArray_Descr *dt2; | |||||
| if(!PyArray_DescrConverter(args[1], &dt2)) { | |||||
| Py_DECREF(dt1); | |||||
| throw ConversionError(ssprintf("can not convert to numpy.dtype from %s", | |||||
| args[1]->ob_type->tp_name)); | |||||
| } | |||||
| auto&& res = (*f)(dt1, dt2); | |||||
| Py_DECREF(dt1); | |||||
| Py_DECREF(dt2); | |||||
| return py::cast(res).release().ptr(); | |||||
| } | |||||
| } catch (std::exception& e) { | |||||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| }; | |||||
| } // anonymous namespace | |||||
| void init_dtypes(py::module m) { | |||||
| static PyMethodDef method_defs[] = { | |||||
| {"is_quantize", (PyCFunction)_wrap<&_is_quantize>::impl, METH_FASTCALL, nullptr}, | |||||
| {"get_scale", (PyCFunction)_wrap<&_get_scale>::impl, METH_FASTCALL, nullptr}, | |||||
| {"get_zero_point", (PyCFunction)_wrap<&_get_zero_point>::impl, METH_FASTCALL, nullptr}, | |||||
| {"is_dtype_equal", (PyCFunction)_wrap<&_is_dtype_equal>::impl, METH_FASTCALL, nullptr}, | |||||
| {nullptr, nullptr, 0, nullptr} | |||||
| }; | |||||
| for (auto&& def: method_defs) { | |||||
| if (def.ml_meth != nullptr) { | |||||
| auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); | |||||
| if (!func) throw py::error_already_set(); | |||||
| py::setattr(m, def.ml_name, func); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace mgb | |||||
| @@ -36,6 +36,7 @@ namespace mgb { | |||||
| int npy_num_intb##n(); | int npy_num_intb##n(); | ||||
| FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) | FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) | ||||
| #undef DEFINE_NPY_INTBX | #undef DEFINE_NPY_INTBX | ||||
| void init_dtypes(pybind11::module m); | |||||
| void init_npy_num_intbx(pybind11::module m); | void init_npy_num_intbx(pybind11::module m); | ||||
| //! numpy type num for bfloat16 type | //! numpy type num for bfloat16 type | ||||
| @@ -9,16 +9,22 @@ | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include "megbrain/dtype.h" | |||||
| #include "megbrain/common.h" | |||||
| #include "./tensor.h" | #include "./tensor.h" | ||||
| #include "./grad.h" | #include "./grad.h" | ||||
| #include "./trace.h" | #include "./trace.h" | ||||
| #include "./common.h" | #include "./common.h" | ||||
| #include "./numpy_dtypes.h" | #include "./numpy_dtypes.h" | ||||
| #include "./graph_rt.h" | #include "./graph_rt.h" | ||||
| #include "./helper.h" | |||||
| #include <pybind11/numpy.h> | #include <pybind11/numpy.h> | ||||
| #include <pybind11/operators.h> | #include <pybind11/operators.h> | ||||
| #include "./helper.h" | |||||
| #include <unordered_map> | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
| @@ -413,6 +419,198 @@ struct TensorWeakRef { | |||||
| } | } | ||||
| }; | }; | ||||
| /* ============== convert inputs ============== */ | |||||
// Map a numpy.dtype.kind character to its promotion priority:
// floating-point (3) > signed/unsigned integer (2) > boolean (1);
// any other kind has priority 0.
inline uint8_t category_priority(char c) {
    if (c == 'f') {
        return 3;
    }
    if (c == 'i' || c == 'u') {
        return 2;
    }
    if (c == 'b') {
        return 1;
    }
    return 0;
}
| // Returns the maximum value of the priority of each type in the list `types`. | |||||
| uint8_t max_priority(SmallVector<PyArray_Descr*> types) { | |||||
| if (types.size() == 0) { | |||||
| return 0; | |||||
| } else { | |||||
| uint8_t max_p = 0; | |||||
| for (auto&& desc: types) { | |||||
| max_p = std::max(max_p, category_priority(desc->kind)); | |||||
| } | |||||
| return max_p; | |||||
| } | |||||
| } | |||||
| // Returns the data type with sufficient size to hold all types of | |||||
| // category `cat` in the list `types`. | |||||
| PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) { | |||||
| // Return value: New reference | |||||
| SmallVector<PyArray_Descr*> used_types; | |||||
| for (auto&& desc: types) { | |||||
| auto&& v = category_priority(desc->kind); | |||||
| if (v == cat) { | |||||
| used_types.emplace_back(desc); | |||||
| } | |||||
| } | |||||
| mgb_assert(used_types.size() > 0, "size of used_types is 0"); | |||||
| PyArray_Descr* res = used_types[0]; | |||||
| Py_INCREF(res); | |||||
| for (size_t i = 1; i < used_types.size(); ++i) { | |||||
| PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res); | |||||
| Py_DECREF(res); | |||||
| res = tmp; | |||||
| } | |||||
| return res; | |||||
| } | |||||
| PyArray_Descr* scalar2dtype(PyObject* arg) { | |||||
| // Return value: New reference | |||||
| if (PyBool_Check(arg)) { | |||||
| auto&& descr = PyArray_DescrFromType(NPY_BOOL); | |||||
| return descr; | |||||
| } | |||||
| if (PyLong_CheckExact(arg)) { | |||||
| auto&& descr = PyArray_DescrFromType(NPY_INT32); | |||||
| return descr; | |||||
| } | |||||
| if (PyFloat_CheckExact(arg)) { | |||||
| auto&& descr = PyArray_DescrFromType(NPY_FLOAT32); | |||||
| return descr; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||||
| // Return value: New reference | |||||
| SmallVector<PyArray_Descr*> tensors; | |||||
| SmallVector<PyArray_Descr*> scalars; | |||||
| bool is_tuple = false; | |||||
| PyObject* tuple; | |||||
| if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { | |||||
| if (PyList_Check(args[0])) { | |||||
| tuple = PyList_AsTuple(args[0]); | |||||
| } else { | |||||
| tuple = args[0]; | |||||
| Py_INCREF(tuple); | |||||
| } | |||||
| nargs = PyTuple_Size(tuple); | |||||
| is_tuple = true; | |||||
| } | |||||
| for (size_t i = 0; i < nargs; ++i) { | |||||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||||
| if (handle == Py_None) continue; | |||||
| TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||||
| if (tw) { | |||||
| mgb::DType type = tw->m_tensor->dtype(); | |||||
| auto&& descr = npy::dtype_mgb2np_descr(type); | |||||
| Py_INCREF(descr.get()); | |||||
| tensors.emplace_back(descr.get()); | |||||
| }else{ | |||||
| if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) { | |||||
| auto&& descr = PyArray_DescrFromObject(handle, nullptr); | |||||
| tensors.emplace_back(descr); | |||||
| continue; | |||||
| } | |||||
| PyArray_Descr* descr = scalar2dtype(handle); | |||||
| if (descr) { | |||||
| scalars.emplace_back(descr); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| } | |||||
| auto max_pri_scalars = max_priority(scalars); | |||||
| auto max_pri_tensors = max_priority(tensors); | |||||
| if (max_pri_scalars <= 0 && max_pri_tensors <= 0) { | |||||
| throw py::value_error("invalid input, no dtype avaliable"); | |||||
| } | |||||
| PyArray_Descr* res; | |||||
| if (max_pri_scalars > max_pri_tensors) { | |||||
| res = promote_types(scalars, max_pri_scalars); | |||||
| }else{ | |||||
| res = promote_types(tensors, max_pri_tensors); | |||||
| } | |||||
| for (auto *p: tensors) { Py_DECREF(p); } | |||||
| for (auto *p: scalars) { Py_DECREF(p); } | |||||
| Py_DECREF(tuple); | |||||
| return res; | |||||
| } | |||||
| CompNode _get_device(PyObject*const* args, size_t nargs) { | |||||
| bool is_tuple = false; | |||||
| PyObject* tuple; | |||||
| if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { | |||||
| if (PyList_Check(args[0])) { | |||||
| tuple = PyList_AsTuple(args[0]); | |||||
| } else { | |||||
| tuple = args[0]; | |||||
| Py_INCREF(tuple); | |||||
| } | |||||
| nargs = PyTuple_Size(tuple); | |||||
| is_tuple = true; | |||||
| } | |||||
| bool valid = false; | |||||
| CompNode cn; | |||||
| for (size_t i = 0; i < nargs; ++i) { | |||||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||||
| TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||||
| if (tw) { | |||||
| if (!valid) { | |||||
| cn = tw->m_tensor->comp_node(); | |||||
| valid = true; | |||||
| } else { | |||||
| CompNode cn1 = tw->m_tensor->comp_node(); | |||||
| if (cn1 != cn) { | |||||
| throw py::value_error(ssprintf("ambiguous device: %s vs %s", | |||||
| cn.to_string().c_str(), cn1.to_string().c_str())); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!valid) { | |||||
| mgb_assert(0, "expact at least 1 device"); | |||||
| } | |||||
| Py_DECREF(tuple); | |||||
| return cn; | |||||
| } | |||||
| // Returns the dtype that would result from performing an arithmetic | |||||
| // operation on the provided input tensors and scalars. | |||||
| PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) { | |||||
| if (!nargs) { | |||||
| PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); | |||||
| return nullptr; | |||||
| } | |||||
| try { | |||||
| PyArray_Descr* res = _dtype_promotion(args, nargs); | |||||
| return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr(); | |||||
| } catch (std::exception& e) { | |||||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) { | |||||
| if (!nargs) { | |||||
| PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); | |||||
| return nullptr; | |||||
| } | |||||
| try { | |||||
| CompNode cn = _get_device(args, nargs); | |||||
| return py::cast(cn).release().ptr(); | |||||
| } catch (std::exception& e) { | |||||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| void init_tensor(py::module m) { | void init_tensor(py::module m) { | ||||
| interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | ||||
| @@ -444,10 +642,19 @@ void init_tensor(py::module m) { | |||||
| .def(py::init<const TensorWrapper&>()) | .def(py::init<const TensorWrapper&>()) | ||||
| .def("__call__", &TensorWeakRef::operator()); | .def("__call__", &TensorWeakRef::operator()); | ||||
| static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; | |||||
| auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr); | |||||
| if (!apply_func) throw py::error_already_set(); | |||||
| py::setattr(m, "apply", apply_func); | |||||
| static PyMethodDef method_defs[] = { | |||||
| {"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}, | |||||
| {"dtype_promotion", (PyCFunction)dtype_promotion, METH_FASTCALL, nullptr}, | |||||
| {"get_device", (PyCFunction)get_device, METH_FASTCALL, nullptr}, | |||||
| {nullptr, nullptr, 0, nullptr} | |||||
| }; | |||||
| for (auto&& def: method_defs) { | |||||
| if (def.ml_meth != nullptr) { | |||||
| auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); | |||||
| if (!func) throw py::error_already_set(); | |||||
| py::setattr(m, def.ml_name, func); | |||||
| } | |||||
| } | |||||
| m.def("_set_swap_flag", | m.def("_set_swap_flag", | ||||
| [](bool flag) { interpreter_for_py->set_swap_flag(flag); }); | [](bool flag) { interpreter_for_py->set_swap_flag(flag); }); | ||||
| @@ -113,7 +113,7 @@ def test_quint8_typecvt(): | |||||
| data = np.random.random(shape).astype(np.float32) * 5 - 1 | data = np.random.random(shape).astype(np.float32) * 5 - 1 | ||||
| def typecvt(x, dt=None): | def typecvt(x, dt=None): | ||||
| (y,) = apply(ops.TypeCvt(dtype=dt), x) | |||||
| (y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||||
| return y | return y | ||||
| # convert to quint8 | # convert to quint8 | ||||
| @@ -194,7 +194,7 @@ def test_quint4_typecvt(): | |||||
| data = np.random.random(shape).astype(np.float32) * 5 - 1 | data = np.random.random(shape).astype(np.float32) * 5 - 1 | ||||
| def typecvt(x, dt=None): | def typecvt(x, dt=None): | ||||
| (y,) = apply(ops.TypeCvt(dtype=dt), x) | |||||
| (y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||||
| return y | return y | ||||
| # convert to quint4 | # convert to quint4 | ||||