GitOrigin-RevId: baef3d348c
tags/v1.2.0
| @@ -8,7 +8,7 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from .._imperative_rt.core2 import Tensor | |||
| # from .._imperative_rt.core2 import Tensor | |||
| from ..tensor.core import OpBase, TensorBase, apply | |||
| @@ -19,5 +19,10 @@ class Const: | |||
| self.device = device | |||
| def __call__(self, *reference): | |||
| Wrapper = type(reference[0]) | |||
| return (Wrapper(self.value, self.dtype, self.device, True),) | |||
| from ...tensor import Tensor | |||
| device = self.device | |||
| if device is None: | |||
| device = reference[0].device | |||
| return (Tensor(self.value, self.dtype, self.device, True),) | |||
| @@ -13,6 +13,12 @@ import numpy as np | |||
| # normal dtype related | |||
| from .._imperative_rt import bfloat16, intb1, intb2, intb4 | |||
| from .._imperative_rt.common import ( | |||
| get_scale, | |||
| get_zero_point, | |||
| is_dtype_equal, | |||
| is_quantize, | |||
| ) | |||
| def is_lowbit(dtype): | |||
| @@ -42,41 +48,6 @@ _metadata_dict = { | |||
| } | |||
| def is_quantize(dtype): | |||
| return ( | |||
| hasattr(dtype, "metadata") | |||
| and dtype.metadata is not None | |||
| and "mgb_dtype" in dtype.metadata | |||
| ) | |||
| def get_scale(dtype): | |||
| assert is_quantize(dtype) | |||
| return dtype.metadata["mgb_dtype"]["scale"] | |||
| def get_zero_point(dtype): | |||
| assert is_quantize(dtype) | |||
| metadata = dtype.metadata["mgb_dtype"] | |||
| assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") | |||
| return metadata["zero_point"] | |||
| def is_equal(dt0, dt1): | |||
| def _get_zero_point(dtype): | |||
| assert is_quantize(dtype) | |||
| metadata = dtype.metadata["mgb_dtype"] | |||
| return metadata.get("zero_point") | |||
| if is_quantize(dt0) and is_quantize(dt1): | |||
| return get_scale(dt0) == get_scale(dt1) and _get_zero_point( | |||
| dt0 | |||
| ) == _get_zero_point(dt1) | |||
| if not (is_quantize(dt0) or is_quantize(dt1)): | |||
| return dt0 == dt1 | |||
| return False | |||
| def _check_zero_point(zp: int, dtype_str: str): | |||
| qmin = _metadata_dict[dtype_str].qmin | |||
| qmax = _metadata_dict[dtype_str].qmax | |||
| @@ -151,9 +151,9 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
| def get_index(i): | |||
| if not isinstance(i, (Tensor)): | |||
| if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | |||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)() | |||
| else: | |||
| (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||
| (i,) = Const(i, dtype=np.int32, device=inp.device)() | |||
| return i | |||
| assert isinstance(i, Tensor) | |||
| if i.dtype != np.bool_: | |||
| @@ -197,7 +197,7 @@ def try_condtake(tensor, index): | |||
| ): | |||
| return [] | |||
| if isinstance(index, np.ndarray): | |||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | |||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)() | |||
| assert isinstance(index, Tensor) | |||
| if not isinstance(tensor, Tensor): | |||
| raise TypeError("input must be a tensor") | |||
| @@ -217,9 +217,7 @@ def getitem(tensor, index): | |||
| if isinstance(v.shape, v.__class__): | |||
| break | |||
| if len(v.shape) > 0 and v.shape[0] == 0: | |||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||
| tensor | |||
| ) | |||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() | |||
| return empty_tensor | |||
| if use_subtensor: | |||
| op = builtin.Subtensor(items=items) | |||
| @@ -240,8 +238,7 @@ def setitem(tensor, index, value): | |||
| return tensor | |||
| tensor = tensor.reshape(-1) | |||
| if not isinstance(value, Tensor): | |||
| op = Const(value, dtype=tensor.dtype, device=tensor.device) | |||
| (value,) = op(tensor) | |||
| (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() | |||
| tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | |||
| for v in tensors: | |||
| if len(v.shape) > 0 and v.shape[0] == 0: | |||
| @@ -11,10 +11,10 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from .._imperative_rt.core2 import Tensor, apply | |||
| from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| from .dtype import is_equal, is_quantize | |||
| from .dtype import is_dtype_equal, is_quantize | |||
| from .megbrain_graph import VarNode | |||
| _enable_convert_inputs = True | |||
| @@ -37,94 +37,12 @@ def set_convert_inputs(flag): | |||
| return backup | |||
| def dtype_promotion(inputs): | |||
| """ | |||
| Returns the dtype that would result from performing an arithmetic | |||
| operation on the provided input tensors and scalars. | |||
| """ | |||
| # map numpy.dtype.kind to priority | |||
| category_priority = { | |||
| "f": 3, # floating-point | |||
| "i": 2, # signed integer | |||
| "u": 2, # unsigned integer | |||
| "b": 1, # boolean | |||
| } | |||
| def scalar2dtype(x): | |||
| """ | |||
| For scalar `x`, returns its corresponding type. A floating point scalar | |||
| has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'. | |||
| A boolean scalar has dtype 'bool'. | |||
| """ | |||
| if isinstance(x, bool): | |||
| return np.bool_ | |||
| if isinstance(x, int): | |||
| return np.int32 | |||
| if isinstance(x, float): | |||
| return np.float32 | |||
| def promote_types(types, cat): | |||
| """ | |||
| Returns the data type with sufficient size to hold all types of | |||
| category `cat` in the list `types`. | |||
| """ | |||
| used_types = [ | |||
| i for i in types if category_priority.get(np.dtype(i).kind, 0) == cat | |||
| ] | |||
| assert len(used_types) > 0 | |||
| res = used_types[0] | |||
| for i in used_types: | |||
| res = np.promote_types(res, i) | |||
| return res | |||
| def max_priority(types): | |||
| """ | |||
| Returns the maximum value of the priority of each type in the list | |||
| `types`. | |||
| """ | |||
| if not types: | |||
| return 0 | |||
| else: | |||
| return max([category_priority.get(np.dtype(i).kind, 0) for i in types]) | |||
| scalars = [] | |||
| tensors = [] | |||
| for data in inputs: | |||
| if hasattr(data, "dtype"): | |||
| tensors.append(data.dtype) | |||
| elif isinstance(data, (float, int, bool)): | |||
| scalars.append(scalar2dtype(data)) | |||
| max_pri_scalars = max_priority(scalars) | |||
| max_pri_tensors = max_priority(tensors) | |||
| assert max_pri_scalars > 0 or max_pri_tensors > 0 | |||
| if max_pri_scalars > max_pri_tensors: | |||
| return promote_types(scalars, max_pri_scalars) | |||
| else: | |||
| return promote_types(tensors, max_pri_tensors) | |||
| def get_device(inputs): | |||
| device = None | |||
| for i in inputs: | |||
| if isinstance(i, (Tensor, VarNode)): | |||
| if device is None: | |||
| device = i.device | |||
| elif device != i.device: | |||
| raise ValueError("ambiguous device: {} vs {}".format(device, i.device)) | |||
| assert device is not None | |||
| return device | |||
| def concatenate(inputs, axis=0, *, device=None): | |||
| dtype = dtype_promotion(inputs) | |||
| device = get_device(inputs) | |||
| def convert(x): | |||
| return convert_single_value(x, inputs, dtype=dtype) | |||
| return convert_single_value(x, dtype=dtype, device=device) | |||
| inputs = tuple(map(convert, inputs)) | |||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | |||
| @@ -133,7 +51,7 @@ def concatenate(inputs, axis=0, *, device=None): | |||
| def astype(x, dtype): | |||
| dtype = np.dtype(dtype) | |||
| if not is_equal(x.dtype, dtype): | |||
| if not is_dtype_equal(x.dtype, dtype): | |||
| isscalar = x.isscalar() | |||
| (x,) = apply(builtin.TypeCvt(dtype=dtype), x) | |||
| if isscalar: | |||
| @@ -141,13 +59,12 @@ def astype(x, dtype): | |||
| return x | |||
| def convert_single_value(v, inputs, *, dtype=None, device=None): | |||
| tensors = [i for i in inputs if isinstance(i, (Tensor, VarNode))] | |||
| assert len(tensors) > 0 | |||
| def convert_single_value(v, *, dtype=None, device=None): | |||
| if isinstance(v, (Tensor, VarNode)): | |||
| v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) | |||
| if not is_quantize(v.dtype): | |||
| v = astype(v, dtype) | |||
| else: | |||
| (v,) = Const(v, dtype=dtype, device=device)(*tensors) | |||
| (v,) = Const(v, dtype=dtype, device=device)() | |||
| return v | |||
| @@ -161,7 +78,7 @@ def convert_inputs(*args: Tensor): | |||
| def convert(value): | |||
| if value is None: | |||
| return value | |||
| return convert_single_value(value, args, dtype=dtype, device=device) | |||
| return convert_single_value(value, dtype=dtype, device=device) | |||
| return tuple(map(convert, args)) | |||
| @@ -703,7 +703,7 @@ def topk( | |||
| op = builtin.TopK(mode=mode) | |||
| if not isinstance(k, Tensor): | |||
| (k,) = Const(k, dtype="int32", device=inp.device)(inp) | |||
| (k,) = Const(k, dtype="int32", device=inp.device)() | |||
| if len(inp.shape) == 1: | |||
| inp = inp.reshape(1, -1) | |||
| @@ -658,7 +658,7 @@ def batch_norm( | |||
| def make_full_if_none(x, value): | |||
| if x is None: | |||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)() | |||
| shape = utils.astensor1d( | |||
| (1, C, 1, 1), inp, dtype="int32", device=inp.device | |||
| ) | |||
| @@ -1567,7 +1567,7 @@ def indexing_one_hot( | |||
| """ | |||
| assert isinstance(src, Tensor), "src must be of Tensor type" | |||
| op = builtin.IndexingOneHot(axis=axis) | |||
| index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) | |||
| index = utils.convert_single_value(index, dtype="int32", device=src.device) | |||
| (result,) = apply(op, src, index) | |||
| if not keepdims: | |||
| result = squeeze(result, axis) | |||
| @@ -107,9 +107,7 @@ def full(shape, value, dtype="float32", device=None): | |||
| shape = (shape,) | |||
| if device is None: | |||
| device = get_default_device() | |||
| (x,) = Const(value, dtype=dtype, device=device)( | |||
| Tensor(value, dtype=dtype, device=device) | |||
| ) | |||
| (x,) = Const(value, dtype=dtype, device=device)() | |||
| if len(shape) == 0: # scalar | |||
| return x | |||
| return broadcast_to(x, shape) | |||
| @@ -265,7 +263,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||
| device = as_device(device) | |||
| def convert(x): | |||
| return convert_single_value(x, inps, dtype=dtype) | |||
| return convert_single_value(x, dtype=dtype, device=device) | |||
| inps = tuple(map(convert, inps)) | |||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | |||
| @@ -37,8 +37,10 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
| else: | |||
| cn = CompNode(device) | |||
| else: | |||
| assert isinstance(device, CompNode) | |||
| cn = device | |||
| if isinstance(device, CompNode): | |||
| cn = device | |||
| else: | |||
| cn = device._cn | |||
| # import pdb; pdb.set_trace() | |||
| if isinstance(data, _Tensor): | |||
| @@ -179,4 +179,5 @@ void init_common(py::module m) { | |||
| init_npy_num_bfloat16(m); | |||
| init_npy_num_intbx(m); | |||
| init_dtypes(m); | |||
| } | |||
| @@ -158,7 +158,7 @@ void PyExceptionForward::throw_() { | |||
| /* ============== namespace npy ============== */ | |||
| namespace { | |||
| namespace npy { | |||
| int to_mgb_supported_dtype_raw(int dtype) { | |||
| if (dtype == NPY_INT64) | |||
| @@ -199,12 +199,6 @@ int dtype_mgb2np_raw(DType dtype) { | |||
| "can not convert dtype %s to numpy dtype", dtype.name())); | |||
| } | |||
| struct PyArrayDescrDeleter { | |||
| void operator()(PyArray_Descr* obj) { | |||
| Py_XDECREF(obj); | |||
| } | |||
| }; | |||
| //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | |||
| //! reference to the descriptor. | |||
| std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | |||
| @@ -585,9 +579,7 @@ void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) { | |||
| HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr)); | |||
| } | |||
| } // anonymous namespace | |||
| PyObject* npy::ndarray_from_tensor( | |||
| PyObject* ndarray_from_tensor( | |||
| const HostTensorND &val, ShareType share_type) { | |||
| if (!val.layout().is_contiguous() && !val.shape().is_empty()) { | |||
| mgb_assert(share_type != ShareType::MUST_SHARE); | |||
| @@ -634,7 +626,7 @@ PyObject* npy::ndarray_from_tensor( | |||
| return ret; | |||
| } | |||
| HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||
| HostTensorND np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||
| auto ret_full = np2tensor_try_borrow(obj, meth, dtype); | |||
| if (meth.must_borrow_) { | |||
| mgb_assert(ret_full.second, | |||
| @@ -645,7 +637,7 @@ HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||
| return ret_full.first; | |||
| } | |||
| PyObject* npy::dtype_mgb2np(mgb::DType dtype) { | |||
| PyObject* dtype_mgb2np(mgb::DType dtype) { | |||
| PYTHON_GIL; | |||
| // According to | |||
| // https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType | |||
| @@ -668,7 +660,7 @@ PyObject* npy::dtype_mgb2np(mgb::DType dtype) { | |||
| return typeobj; | |||
| } | |||
| mgb::DType npy::dtype_np2mgb(PyObject *obj) { | |||
| mgb::DType dtype_np2mgb(PyObject *obj) { | |||
| mgb_assert(obj && obj != Py_None, | |||
| "can not convert null PyObject to numpy dtype"); | |||
| // see | |||
| @@ -686,7 +678,7 @@ mgb::DType npy::dtype_np2mgb(PyObject *obj) { | |||
| return result; | |||
| } | |||
| PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { | |||
| PyObject* to_mgb_supported_dtype(PyObject* dtype) { | |||
| PYTHON_GIL; | |||
| PyArray_Descr* descr; | |||
| @@ -702,7 +694,7 @@ PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { | |||
| return PyArray_TypeObjectFromType(type_num); | |||
| } | |||
| TensorShape npy::vec2shape(const std::vector<size_t> &vec) { | |||
| TensorShape vec2shape(const std::vector<size_t> &vec) { | |||
| TensorShape shape; | |||
| mgb_assert(vec.size() <= TensorShape::MAX_NDIM, | |||
| "dim too large: %zd (max %zd)", | |||
| @@ -718,3 +710,5 @@ TensorShape npy::vec2shape(const std::vector<size_t> &vec) { | |||
| mgb_assert(shape.ndim, "shape should not be empty"); | |||
| return shape; | |||
| } | |||
| } // namespace npy | |||
| @@ -11,7 +11,7 @@ | |||
| #pragma once | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/common.h" | |||
| #include "megbrain/utils/persistent_cache.h" | |||
| #include "megbrain/imperative/op_def.h" | |||
| @@ -26,6 +26,8 @@ | |||
| #include <pybind11/numpy.h> | |||
| #include <pybind11/functional.h> | |||
| #include "./numpy_dtypes.h" | |||
| pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr); | |||
| pybind11::module rel_import(pybind11::str name, pybind11::module m, int level); | |||
| @@ -182,6 +184,18 @@ namespace npy { | |||
| //! convert raw vector to tensor shape | |||
| mgb::TensorShape vec2shape(const std::vector<size_t> &vec); | |||
| struct PyArrayDescrDeleter { | |||
| void operator()(PyArray_Descr* obj) { | |||
| Py_XDECREF(obj); | |||
| } | |||
| }; | |||
| //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | |||
| //! reference to the descriptor. | |||
| std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(mgb::DType dtype); | |||
| mgb::DType dtype_np2mgb_descr(PyArray_Descr* descr); | |||
| //! convert megbrain dtype to numpy dtype object; return new reference | |||
| PyObject* dtype_mgb2np(mgb::DType dtype); | |||
| @@ -0,0 +1,179 @@ | |||
| /** | |||
| * \file imperative/python/src/numpy_dtypes.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./numpy_dtypes.h" | |||
| #include "./helper.h" | |||
| #include "./pyext17.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include <cstring> | |||
| namespace py = pybind11; | |||
| namespace mgb { | |||
| namespace { | |||
| inline bool _is_quantize(PyArray_Descr* dtype) { | |||
| static PyObject* PY_MGB_DTYPE_KEY = PyUnicode_FromString("mgb_dtype"); | |||
| return dtype->metadata && | |||
| PyDict_CheckExact(dtype->metadata) && | |||
| PyDict_Contains(dtype->metadata, PY_MGB_DTYPE_KEY) == 1; | |||
| } | |||
| PyObject* _get_mgb_dtype(PyArray_Descr* dtype) { | |||
| // Return value: New reference. | |||
| if (!_is_quantize(dtype)) { | |||
| throw py::type_error("expact quantize dtype"); | |||
| } | |||
| PyObject* ob = PyDict_GetItemString(dtype->metadata, "mgb_dtype"); | |||
| if (!PyDict_CheckExact(ob)) { | |||
| throw py::type_error("mgb_dtype is not dict"); | |||
| } | |||
| Py_INCREF(ob); | |||
| return ob; | |||
| } | |||
| double _get_scale(PyArray_Descr* dtype) { | |||
| PyObject* ob = _get_mgb_dtype(dtype); | |||
| PyObject* scale = PyDict_GetItemString(ob, "scale"); | |||
| if (!scale) { | |||
| Py_DECREF(ob); | |||
| throw py::key_error("scale"); | |||
| } | |||
| if (!PyFloat_Check(scale)) { | |||
| Py_DECREF(ob); | |||
| throw py::type_error("scale is not float"); | |||
| } | |||
| double ret = PyFloat_AsDouble(scale); | |||
| Py_DECREF(ob); | |||
| return ret; | |||
| } | |||
| long _get_zero_point(PyArray_Descr* dtype) { | |||
| PyObject* ob = _get_mgb_dtype(dtype); | |||
| PyObject* name = PyDict_GetItemString(ob, "name"); | |||
| if (!name) { | |||
| Py_DECREF(ob); | |||
| throw py::key_error("name"); | |||
| } | |||
| const char* s = PyUnicode_AsUTF8(name); | |||
| if (strcmp(s, "Quantized8Asymm") != 0 && strcmp(s, "Quantized4Asymm") != 0) { | |||
| Py_DECREF(ob); | |||
| throw py::value_error(ssprintf("expect name to be \"Quantized8Asymm\" or \"Quantized4Asymm\", got %s", s)); | |||
| } | |||
| PyObject* zp = PyDict_GetItemString(ob, "zero_point"); | |||
| if (!zp) { | |||
| Py_DECREF(ob); | |||
| throw py::key_error("zero_point"); | |||
| } | |||
| long ret = PyLong_AsLong(zp); | |||
| Py_DECREF(ob); | |||
| return ret; | |||
| } | |||
| bool _is_dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) { | |||
| bool q1 = _is_quantize(dt1), | |||
| q2 = _is_quantize(dt2); | |||
| if (q1 && q2) { | |||
| if (_get_scale(dt1) != _get_scale(dt2)) { | |||
| return false; | |||
| } | |||
| PyObject* zp1 = PyDict_GetItemString( | |||
| PyDict_GetItemString(dt1->metadata, "mgb_dtype"), "zero_point"); | |||
| PyObject* zp2 = PyDict_GetItemString( | |||
| PyDict_GetItemString(dt2->metadata, "mgb_dtype"), "zero_point"); | |||
| if (!zp1 || !zp2) { | |||
| throw py::key_error("zero_point"); | |||
| } | |||
| return PyLong_AsLong(zp1) == PyLong_AsLong(zp2); | |||
| } | |||
| if (!q1 && !q2) { | |||
| return dt1->type_num == dt2->type_num; | |||
| } | |||
| return false; | |||
| } | |||
| template<auto f> | |||
| struct _wrap { | |||
| static constexpr size_t n_args = []() { | |||
| using F = decltype(f); | |||
| using T = PyArray_Descr*; | |||
| static_assert(std::is_pointer<F>::value); | |||
| if constexpr (std::is_invocable<F, T>::value) { | |||
| return 1; | |||
| } else if constexpr (std::is_invocable<F, T, T>::value) { | |||
| return 2; | |||
| } else { | |||
| static_assert(!std::is_same_v<F, F>, "unreachable"); | |||
| } | |||
| }(); | |||
| static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargs) { | |||
| if (nargs != n_args) { | |||
| PyErr_Format(PyExc_ValueError, "expected %lu arguments", n_args); | |||
| return nullptr; | |||
| } | |||
| for (size_t i=0; i<nargs; ++i) { | |||
| if (args[i] == Py_None) { | |||
| PyErr_SetString(PyExc_ValueError, "can not convert null PyObject to numpy dtype"); | |||
| return nullptr; | |||
| } | |||
| } | |||
| try { | |||
| PyArray_Descr *dt1; | |||
| if(!PyArray_DescrConverter(args[0], &dt1)) { | |||
| throw ConversionError(ssprintf("can not convert to numpy.dtype from %s", | |||
| args[0]->ob_type->tp_name)); | |||
| } | |||
| if constexpr (n_args == 1) { | |||
| auto res = (*f)(dt1); | |||
| Py_DECREF(dt1); | |||
| return py::cast(res).release().ptr(); | |||
| } else { | |||
| PyArray_Descr *dt2; | |||
| if(!PyArray_DescrConverter(args[1], &dt2)) { | |||
| Py_DECREF(dt1); | |||
| throw ConversionError(ssprintf("can not convert to numpy.dtype from %s", | |||
| args[1]->ob_type->tp_name)); | |||
| } | |||
| auto&& res = (*f)(dt1, dt2); | |||
| Py_DECREF(dt1); | |||
| Py_DECREF(dt2); | |||
| return py::cast(res).release().ptr(); | |||
| } | |||
| } catch (std::exception& e) { | |||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | |||
| return nullptr; | |||
| } | |||
| } | |||
| }; | |||
| } // anonymous namespace | |||
| void init_dtypes(py::module m) { | |||
| static PyMethodDef method_defs[] = { | |||
| {"is_quantize", (PyCFunction)_wrap<&_is_quantize>::impl, METH_FASTCALL, nullptr}, | |||
| {"get_scale", (PyCFunction)_wrap<&_get_scale>::impl, METH_FASTCALL, nullptr}, | |||
| {"get_zero_point", (PyCFunction)_wrap<&_get_zero_point>::impl, METH_FASTCALL, nullptr}, | |||
| {"is_dtype_equal", (PyCFunction)_wrap<&_is_dtype_equal>::impl, METH_FASTCALL, nullptr}, | |||
| {nullptr, nullptr, 0, nullptr} | |||
| }; | |||
| for (auto&& def: method_defs) { | |||
| if (def.ml_meth != nullptr) { | |||
| auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); | |||
| if (!func) throw py::error_already_set(); | |||
| py::setattr(m, def.ml_name, func); | |||
| } | |||
| } | |||
| } | |||
| } // namespace mgb | |||
| @@ -36,6 +36,7 @@ namespace mgb { | |||
| int npy_num_intb##n(); | |||
| FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) | |||
| #undef DEFINE_NPY_INTBX | |||
| void init_dtypes(pybind11::module m); | |||
| void init_npy_num_intbx(pybind11::module m); | |||
| //! numpy type num for bfloat16 type | |||
| @@ -9,16 +9,22 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/dtype.h" | |||
| #include "megbrain/common.h" | |||
| #include "./tensor.h" | |||
| #include "./grad.h" | |||
| #include "./trace.h" | |||
| #include "./common.h" | |||
| #include "./numpy_dtypes.h" | |||
| #include "./graph_rt.h" | |||
| #include "./helper.h" | |||
| #include <pybind11/numpy.h> | |||
| #include <pybind11/operators.h> | |||
| #include "./helper.h" | |||
| #include <unordered_map> | |||
| namespace py = pybind11; | |||
| namespace mgb::imperative::python { | |||
| @@ -413,6 +419,198 @@ struct TensorWeakRef { | |||
| } | |||
| }; | |||
| /* ============== convert inputs ============== */ | |||
| // map numpy.dtype.kind to priority | |||
| inline uint8_t category_priority(char c) { | |||
| switch (c) { | |||
| case 'f': return 3; // floating-point | |||
| case 'i': return 2; // signed integer | |||
| case 'u': return 2; // unsigned integer | |||
| case 'b': return 1; // boolean | |||
| default: return 0; | |||
| } | |||
| } | |||
| // Returns the maximum value of the priority of each type in the list `types`. | |||
| uint8_t max_priority(SmallVector<PyArray_Descr*> types) { | |||
| if (types.size() == 0) { | |||
| return 0; | |||
| } else { | |||
| uint8_t max_p = 0; | |||
| for (auto&& desc: types) { | |||
| max_p = std::max(max_p, category_priority(desc->kind)); | |||
| } | |||
| return max_p; | |||
| } | |||
| } | |||
| // Returns the data type with sufficient size to hold all types of | |||
| // category `cat` in the list `types`. | |||
| PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) { | |||
| // Return value: New reference | |||
| SmallVector<PyArray_Descr*> used_types; | |||
| for (auto&& desc: types) { | |||
| auto&& v = category_priority(desc->kind); | |||
| if (v == cat) { | |||
| used_types.emplace_back(desc); | |||
| } | |||
| } | |||
| mgb_assert(used_types.size() > 0, "size of used_types is 0"); | |||
| PyArray_Descr* res = used_types[0]; | |||
| Py_INCREF(res); | |||
| for (size_t i = 1; i < used_types.size(); ++i) { | |||
| PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res); | |||
| Py_DECREF(res); | |||
| res = tmp; | |||
| } | |||
| return res; | |||
| } | |||
| PyArray_Descr* scalar2dtype(PyObject* arg) { | |||
| // Return value: New reference | |||
| if (PyBool_Check(arg)) { | |||
| auto&& descr = PyArray_DescrFromType(NPY_BOOL); | |||
| return descr; | |||
| } | |||
| if (PyLong_CheckExact(arg)) { | |||
| auto&& descr = PyArray_DescrFromType(NPY_INT32); | |||
| return descr; | |||
| } | |||
| if (PyFloat_CheckExact(arg)) { | |||
| auto&& descr = PyArray_DescrFromType(NPY_FLOAT32); | |||
| return descr; | |||
| } | |||
| return nullptr; | |||
| } | |||
| PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||
| // Return value: New reference | |||
| SmallVector<PyArray_Descr*> tensors; | |||
| SmallVector<PyArray_Descr*> scalars; | |||
| bool is_tuple = false; | |||
| PyObject* tuple; | |||
| if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { | |||
| if (PyList_Check(args[0])) { | |||
| tuple = PyList_AsTuple(args[0]); | |||
| } else { | |||
| tuple = args[0]; | |||
| Py_INCREF(tuple); | |||
| } | |||
| nargs = PyTuple_Size(tuple); | |||
| is_tuple = true; | |||
| } | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
| if (handle == Py_None) continue; | |||
| TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||
| if (tw) { | |||
| mgb::DType type = tw->m_tensor->dtype(); | |||
| auto&& descr = npy::dtype_mgb2np_descr(type); | |||
| Py_INCREF(descr.get()); | |||
| tensors.emplace_back(descr.get()); | |||
| }else{ | |||
| if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) { | |||
| auto&& descr = PyArray_DescrFromObject(handle, nullptr); | |||
| tensors.emplace_back(descr); | |||
| continue; | |||
| } | |||
| PyArray_Descr* descr = scalar2dtype(handle); | |||
| if (descr) { | |||
| scalars.emplace_back(descr); | |||
| continue; | |||
| } | |||
| } | |||
| } | |||
| auto max_pri_scalars = max_priority(scalars); | |||
| auto max_pri_tensors = max_priority(tensors); | |||
| if (max_pri_scalars <= 0 && max_pri_tensors <= 0) { | |||
| throw py::value_error("invalid input, no dtype avaliable"); | |||
| } | |||
| PyArray_Descr* res; | |||
| if (max_pri_scalars > max_pri_tensors) { | |||
| res = promote_types(scalars, max_pri_scalars); | |||
| }else{ | |||
| res = promote_types(tensors, max_pri_tensors); | |||
| } | |||
| for (auto *p: tensors) { Py_DECREF(p); } | |||
| for (auto *p: scalars) { Py_DECREF(p); } | |||
| Py_DECREF(tuple); | |||
| return res; | |||
| } | |||
| CompNode _get_device(PyObject*const* args, size_t nargs) { | |||
| bool is_tuple = false; | |||
| PyObject* tuple; | |||
| if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { | |||
| if (PyList_Check(args[0])) { | |||
| tuple = PyList_AsTuple(args[0]); | |||
| } else { | |||
| tuple = args[0]; | |||
| Py_INCREF(tuple); | |||
| } | |||
| nargs = PyTuple_Size(tuple); | |||
| is_tuple = true; | |||
| } | |||
| bool valid = false; | |||
| CompNode cn; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
| TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||
| if (tw) { | |||
| if (!valid) { | |||
| cn = tw->m_tensor->comp_node(); | |||
| valid = true; | |||
| } else { | |||
| CompNode cn1 = tw->m_tensor->comp_node(); | |||
| if (cn1 != cn) { | |||
| throw py::value_error(ssprintf("ambiguous device: %s vs %s", | |||
| cn.to_string().c_str(), cn1.to_string().c_str())); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (!valid) { | |||
| mgb_assert(0, "expact at least 1 device"); | |||
| } | |||
| Py_DECREF(tuple); | |||
| return cn; | |||
| } | |||
| // Returns the dtype that would result from performing an arithmetic | |||
| // operation on the provided input tensors and scalars. | |||
| PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) { | |||
| if (!nargs) { | |||
| PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); | |||
| return nullptr; | |||
| } | |||
| try { | |||
| PyArray_Descr* res = _dtype_promotion(args, nargs); | |||
| return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr(); | |||
| } catch (std::exception& e) { | |||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | |||
| return nullptr; | |||
| } | |||
| } | |||
| PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) { | |||
| if (!nargs) { | |||
| PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); | |||
| return nullptr; | |||
| } | |||
| try { | |||
| CompNode cn = _get_device(args, nargs); | |||
| return py::cast(cn).release().ptr(); | |||
| } catch (std::exception& e) { | |||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | |||
| return nullptr; | |||
| } | |||
| } | |||
| void init_tensor(py::module m) { | |||
| interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | |||
| @@ -444,10 +642,19 @@ void init_tensor(py::module m) { | |||
| .def(py::init<const TensorWrapper&>()) | |||
| .def("__call__", &TensorWeakRef::operator()); | |||
| static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; | |||
| auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr); | |||
| if (!apply_func) throw py::error_already_set(); | |||
| py::setattr(m, "apply", apply_func); | |||
| static PyMethodDef method_defs[] = { | |||
| {"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}, | |||
| {"dtype_promotion", (PyCFunction)dtype_promotion, METH_FASTCALL, nullptr}, | |||
| {"get_device", (PyCFunction)get_device, METH_FASTCALL, nullptr}, | |||
| {nullptr, nullptr, 0, nullptr} | |||
| }; | |||
| for (auto&& def: method_defs) { | |||
| if (def.ml_meth != nullptr) { | |||
| auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); | |||
| if (!func) throw py::error_already_set(); | |||
| py::setattr(m, def.ml_name, func); | |||
| } | |||
| } | |||
| m.def("_set_swap_flag", | |||
| [](bool flag) { interpreter_for_py->set_swap_flag(flag); }); | |||
| @@ -113,7 +113,7 @@ def test_quint8_typecvt(): | |||
| data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
| def typecvt(x, dt=None): | |||
| (y,) = apply(ops.TypeCvt(dtype=dt), x) | |||
| (y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||
| return y | |||
| # convert to quint8 | |||
| @@ -194,7 +194,7 @@ def test_quint4_typecvt(): | |||
| data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
| def typecvt(x, dt=None): | |||
| (y,) = apply(ops.TypeCvt(dtype=dt), x) | |||
| (y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||
| return y | |||
| # convert to quint4 | |||