@@ -19,6 +19,7 @@ from .._imperative_rt.core2 import (
     SymbolVar,
     Tensor,
     apply,
+    astype_cpp,
     broadcast_cpp,
     dtype_promotion,
 )
@@ -27,14 +28,7 @@ from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
 from ..ops import builtin
 from . import amp
 from .indexing import getitem, setitem
-from .utils import (
-    _normalize_axis,
-    astensor1d,
-    astype,
-    cast_tensors,
-    make_shape_tuple,
-    subgraph,
-)
+from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph


 _ElwMod = builtin.Elemwise.Mode
@@ -605,7 +599,7 @@ class ArrayMethodMixin(abc.ABC):
         r"""Returns a :class:`Tensor` with the same data and number of elements
         with the specified :attr:`~.Tensor.dtype`.
         """
-        return astype(self, dtype)
+        return astype_cpp(self, dtype)

     def reshape(self, *args):
         r"""See :func:`~.reshape`."""
@@ -20,6 +20,9 @@ from .._imperative_rt.core2 import (
     _get_convert_inputs,
     _set_convert_inputs,
     apply,
+    astype_cpp,
+    convert_inputs_cpp,
+    convert_single_value_cpp,
     dtype_promotion,
     get_device,
     make_shape_tuple,
@@ -55,53 +58,14 @@ def concatenate(inputs, axis=0, *, device=None):
     return result


-def astype(x, dtype):
-    dtype = np.dtype(dtype)
-    if not is_dtype_equal(x.dtype, dtype):
-        (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
-    return x
-
-
 def convert_single_value(v, *, dtype=None, device=None):
-    if isinstance(v, (Tensor, SymbolVar)):
-        if not is_quantize(v.dtype):
-            v = astype(v, dtype)
-    else:
-        v = Const(v, dtype, device, None)
-    return v
+    return convert_single_value_cpp(v, dtype, device)


 def convert_inputs(*args, device=None):
     if not _get_convert_inputs():
         return args
-
-    dtype = dtype_promotion(args)
-    if device is None:
-        device = get_device(args)
-    device = as_device(device)
-
-    graph = None
-    sym_type = None
-    for a in args:
-        if isinstance(a, SymbolVar):
-            if graph is None:
-                graph = a.var.graph
-                sym_type = type(a)
-            else:
-                assert graph == a.var.graph
-
-    args = list(args)
-    if graph is not None:
-        for i in range(len(args)):
-            if not isinstance(args[i], SymbolVar):
-                rst = make_const(graph, np.array(args[i]), device.to_c(), dtype)
-                args[i] = sym_type(rst)
-
-    def convert(value):
-        if value is None:
-            return value
-        return convert_single_value(value, dtype=dtype, device=device.to_c())
-
-    return tuple(map(convert, args))
+    return convert_inputs_cpp(*args, device)


 def cast_tensors(*args, promote=False):
@@ -146,7 +110,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
         pass
     except ValueError:
         if dtype is not None and dtype != x.dtype:
-            x = astype(x, dtype)
+            x = astype_cpp(x, dtype)
         if device is not None:
             cn = as_device(device).to_c()
             (x,) = apply(builtin.Copy(comp_node=cn), x)
@@ -164,7 +128,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
     if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
         x = concatenate(x, device=device) if len(x) > 1 else x[0]
         if dtype is not None:
-            x = astype(x, dtype)
+            x = astype_cpp(x, dtype)
         return x
     x = Const(x, dtype, device, reference)
     return x
@@ -30,7 +30,6 @@ from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
 from ..core.tensor.utils import (
     astensor1d,
-    astype,
     cast_tensors,
     convert_single_value,
     make_shape_tuple,
@@ -170,6 +170,12 @@ struct _wrap {
 } // anonymous namespace

+namespace imperative::python {
+bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) {
+    return _is_dtype_equal(dt1, dt2);
+}
+} // namespace imperative::python
+
 #ifdef METH_FASTCALL
 #define MGE_PY_INTERFACE(NAME, FUN) \
     { #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr }
@@ -26,6 +26,11 @@
     cb(BFloat16, npy_num_bfloat16())

 namespace mgb {
+
+namespace imperative::python {
+bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2);
+} // namespace imperative::python
+
 //! numpy type num for intb1/2/4 type
 #define DEFINE_NPY_INTBX(n) int npy_num_intb##n();
 FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX)
@@ -400,223 +400,6 @@ struct TensorWeakRef {
     int _use_cnt() { return wptr.use_count(); }
 };

-/* ============== convert inputs ============== */
-
-// map numpy.dtype.kind to priority
-inline uint8_t category_priority(char c) {
-    switch (c) {
-        case 'f':
-            return 3; // floating-point
-        case 'i':
-            return 2; // signed integer
-        case 'u':
-            return 2; // unsigned integer
-        case 'b':
-            return 1; // boolean
-        default:
-            return 0;
-    }
-}
-
-// Returns the maximum value of the priority of each type in the list `types`.
-uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
-    if (types.size() == 0) {
-        return 0;
-    } else {
-        uint8_t max_p = 0;
-        for (auto&& desc : types) {
-            max_p = std::max(max_p, category_priority(desc->kind));
-        }
-        return max_p;
-    }
-}
-
-// Returns the data type with sufficient size to hold all types of
-// category `cat` in the list `types`.
-PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
-    // Return value: New reference
-    SmallVector<PyArray_Descr*> used_types;
-    for (auto&& desc : types) {
-        auto&& v = category_priority(desc->kind);
-        if (v == cat) {
-            used_types.emplace_back(desc);
-        }
-    }
-    mgb_assert(used_types.size() > 0, "size of used_types is 0");
-    PyArray_Descr* res = used_types[0];
-    Py_INCREF(res);
-    for (size_t i = 1; i < used_types.size(); ++i) {
-        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
-        Py_DECREF(res);
-        res = tmp;
-    }
-    return res;
-}
-
-PyArray_Descr* scalar2dtype(PyObject* arg) {
-    // Return value: New reference
-    if (PyBool_Check(arg)) {
-        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
-        return descr;
-    }
-    if (PyLong_CheckExact(arg)) {
-        auto&& descr = PyArray_DescrFromType(NPY_INT32);
-        return descr;
-    }
-    if (PyFloat_CheckExact(arg)) {
-        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
-        return descr;
-    }
-    return nullptr;
-}
-
-PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
-    // Return value: New reference
-    SmallVector<PyArray_Descr*> tensors;
-    SmallVector<PyArray_Descr*> scalars;
-
-    bool is_tuple = false;
-    PyObject* tuple = nullptr;
-    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
-        if (PyList_Check(args[0])) {
-            tuple = PyList_AsTuple(args[0]);
-        } else {
-            tuple = args[0];
-            Py_INCREF(tuple);
-        }
-        nargs = PyTuple_Size(tuple);
-        is_tuple = true;
-    }
-
-    for (size_t i = 0; i < nargs; ++i) {
-        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
-        if (handle == Py_None)
-            continue;
-        TensorWrapper* tw = TensorWrapper::try_cast(handle);
-        if (tw) {
-            mgb::DType type = tw->m_tensor->dtype();
-            auto&& descr = npy::dtype_mgb2np_descr(type);
-            Py_INCREF(descr.get());
-            tensors.emplace_back(descr.get());
-        } else {
-            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
-                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
-                tensors.emplace_back(descr);
-                continue;
-            }
-            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
-                auto var = py::handle(handle).cast<PySymbolVar*>();
-                mgb::DType type = var->m_node->dtype();
-                auto&& descr = npy::dtype_mgb2np_descr(type);
-                Py_INCREF(descr.get());
-                tensors.emplace_back(descr.get());
-                continue;
-            }
-            PyArray_Descr* descr = scalar2dtype(handle);
-            if (descr) {
-                scalars.emplace_back(descr);
-                continue;
-            }
-        }
-    }
-
-    auto max_pri_scalars = max_priority(scalars);
-    auto max_pri_tensors = max_priority(tensors);
-
-    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
-        throw py::value_error("invalid input, no dtype avaliable");
-    }
-
-    PyArray_Descr* res;
-    if (max_pri_scalars > max_pri_tensors) {
-        res = promote_types(scalars, max_pri_scalars);
-    } else {
-        res = promote_types(tensors, max_pri_tensors);
-    }
-    for (auto* p : tensors) {
-        Py_DECREF(p);
-    }
-    for (auto* p : scalars) {
-        Py_DECREF(p);
-    }
-    Py_XDECREF(tuple);
-    return res;
-}
-
-CompNode _get_device(PyObject* const* args, size_t nargs) {
-    bool is_tuple = false;
-    PyObject* tuple = nullptr;
-    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
-        if (PyList_Check(args[0])) {
-            tuple = PyList_AsTuple(args[0]);
-        } else {
-            tuple = args[0];
-            Py_INCREF(tuple);
-        }
-        nargs = PyTuple_Size(tuple);
-        is_tuple = true;
-    }
-    bool valid = false;
-    CompNode cn;
-    for (size_t i = 0; i < nargs; ++i) {
-        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
-        TensorWrapper* tw = TensorWrapper::try_cast(handle);
-        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
-        if (tw || is_symvar) {
-            if (!valid) {
-                cn = tw ? tw->m_tensor->comp_node()
-                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
-                valid = true;
-            } else {
-                CompNode cn1 = tw ? tw->m_tensor->comp_node()
-                                  : py::handle(handle)
-                                            .cast<PySymbolVar*>()
-                                            ->m_node->comp_node();
-                if (cn1 != cn) {
-                    throw py::value_error(ssprintf(
-                            "ambiguous device: %s (from %s) vs %s (from %s)",
-                            cn.to_string().c_str(), cn.to_string_logical().c_str(),
-                            cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
-                }
-            }
-        }
-    }
-    if (!valid) {
-        return CompNode::load(get_default_device());
-    }
-    Py_XDECREF(tuple);
-    return cn;
-}
-
-// Returns the dtype that would result from performing an arithmetic
-// operation on the provided input tensors and scalars.
-PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
-    if (!nargs) {
-        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
-        return nullptr;
-    }
-    try {
-        PyArray_Descr* res = _dtype_promotion(args, nargs);
-        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
-    }
-    PYEXT17_TRANSLATE_EXC_RET(nullptr)
-}
-
-PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
-    if (!nargs) {
-        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
-        return nullptr;
-    }
-    try {
-        CompNode cn = _get_device(args, nargs);
-        return py::cast(cn).release().ptr();
-    }
-    PYEXT17_TRANSLATE_EXC_RET(nullptr)
-}
-
 #ifdef METH_FASTCALL
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
@@ -640,6 +423,9 @@ WRAP_FUNC_PY35(transpose_cpp);
 WRAP_FUNC_PY35(broadcast_cpp);
 WRAP_FUNC_PY35(reshape_cpp);
 WRAP_FUNC_PY35(Const);
+WRAP_FUNC_PY35(astype_cpp);
+WRAP_FUNC_PY35(convert_single_value_cpp);
+WRAP_FUNC_PY35(convert_inputs_cpp);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -779,6 +565,9 @@ void init_tensor(py::module m) {
             MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
             MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
             MGE_PY_INTERFACE(Const, Const),
+            MGE_PY_INTERFACE(astype_cpp, astype_cpp),
+            MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
+            MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
             {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {
@@ -52,6 +52,223 @@ namespace views = ranges::views;

 namespace mgb::imperative::python {

+/* ============== convert inputs ============== */
+
+// map numpy.dtype.kind to priority
+inline uint8_t category_priority(char c) {
+    switch (c) {
+        case 'f':
+            return 3; // floating-point
+        case 'i':
+            return 2; // signed integer
+        case 'u':
+            return 2; // unsigned integer
+        case 'b':
+            return 1; // boolean
+        default:
+            return 0;
+    }
+}
+
+// Returns the maximum value of the priority of each type in the list `types`.
+uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
+    if (types.size() == 0) {
+        return 0;
+    } else {
+        uint8_t max_p = 0;
+        for (auto&& desc : types) {
+            max_p = std::max(max_p, category_priority(desc->kind));
+        }
+        return max_p;
+    }
+}
+
+// Returns the data type with sufficient size to hold all types of
+// category `cat` in the list `types`.
+PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
+    // Return value: New reference
+    SmallVector<PyArray_Descr*> used_types;
+    for (auto&& desc : types) {
+        auto&& v = category_priority(desc->kind);
+        if (v == cat) {
+            used_types.emplace_back(desc);
+        }
+    }
+    mgb_assert(used_types.size() > 0, "size of used_types is 0");
+    PyArray_Descr* res = used_types[0];
+    Py_INCREF(res);
+    for (size_t i = 1; i < used_types.size(); ++i) {
+        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
+        Py_DECREF(res);
+        res = tmp;
+    }
+    return res;
+}
+
+PyArray_Descr* scalar2dtype(PyObject* arg) {
+    // Return value: New reference
+    if (PyBool_Check(arg)) {
+        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
+        return descr;
+    }
+    if (PyLong_CheckExact(arg)) {
+        auto&& descr = PyArray_DescrFromType(NPY_INT32);
+        return descr;
+    }
+    if (PyFloat_CheckExact(arg)) {
+        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
+        return descr;
+    }
+    return nullptr;
+}
+
+PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
+    // Return value: New reference
+    SmallVector<PyArray_Descr*> tensors;
+    SmallVector<PyArray_Descr*> scalars;
+
+    bool is_tuple = false;
+    PyObject* tuple = nullptr;
+    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
+        if (PyList_Check(args[0])) {
+            tuple = PyList_AsTuple(args[0]);
+        } else {
+            tuple = args[0];
+            Py_INCREF(tuple);
+        }
+        nargs = PyTuple_Size(tuple);
+        is_tuple = true;
+    }
+
+    for (size_t i = 0; i < nargs; ++i) {
+        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
+        if (handle == Py_None)
+            continue;
+        TensorWrapper* tw = TensorWrapper::try_cast(handle);
+        if (tw) {
+            mgb::DType type = tw->m_tensor->dtype();
+            auto&& descr = npy::dtype_mgb2np_descr(type);
+            Py_INCREF(descr.get());
+            tensors.emplace_back(descr.get());
+        } else {
+            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
+                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
+                tensors.emplace_back(descr);
+                continue;
+            }
+            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
+                auto var = py::handle(handle).cast<PySymbolVar*>();
+                mgb::DType type = var->m_node->dtype();
+                auto&& descr = npy::dtype_mgb2np_descr(type);
+                Py_INCREF(descr.get());
+                tensors.emplace_back(descr.get());
+                continue;
+            }
+            PyArray_Descr* descr = scalar2dtype(handle);
+            if (descr) {
+                scalars.emplace_back(descr);
+                continue;
+            }
+        }
+    }
+
+    auto max_pri_scalars = max_priority(scalars);
+    auto max_pri_tensors = max_priority(tensors);
+
+    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
+        throw py::value_error("invalid input, no dtype avaliable");
+    }
+
+    PyArray_Descr* res;
+    if (max_pri_scalars > max_pri_tensors) {
+        res = promote_types(scalars, max_pri_scalars);
+    } else {
+        res = promote_types(tensors, max_pri_tensors);
+    }
+    for (auto* p : tensors) {
+        Py_DECREF(p);
+    }
+    for (auto* p : scalars) {
+        Py_DECREF(p);
+    }
+    Py_XDECREF(tuple);
+    return res;
+}
+
+CompNode _get_device(PyObject* const* args, size_t nargs) {
+    bool is_tuple = false;
+    PyObject* tuple = nullptr;
+    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
+        if (PyList_Check(args[0])) {
+            tuple = PyList_AsTuple(args[0]);
+        } else {
+            tuple = args[0];
+            Py_INCREF(tuple);
+        }
+        nargs = PyTuple_Size(tuple);
+        is_tuple = true;
+    }
+    bool valid = false;
+    CompNode cn;
+    for (size_t i = 0; i < nargs; ++i) {
+        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
+        TensorWrapper* tw = TensorWrapper::try_cast(handle);
+        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
+        if (tw || is_symvar) {
+            if (!valid) {
+                cn = tw ? tw->m_tensor->comp_node()
+                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
+                valid = true;
+            } else {
+                CompNode cn1 = tw ? tw->m_tensor->comp_node()
+                                  : py::handle(handle)
+                                            .cast<PySymbolVar*>()
+                                            ->m_node->comp_node();
+                if (cn1 != cn) {
+                    throw py::value_error(ssprintf(
+                            "ambiguous device: %s (from %s) vs %s (from %s)",
+                            cn.to_string().c_str(), cn.to_string_logical().c_str(),
+                            cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
+                }
+            }
+        }
+    }
+    if (!valid) {
+        return CompNode::load(get_default_device());
+    }
+    Py_XDECREF(tuple);
+    return cn;
+}
+
+// Returns the dtype that would result from performing an arithmetic
+// operation on the provided input tensors and scalars.
+PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
+    if (!nargs) {
+        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
+        return nullptr;
+    }
+    try {
+        PyArray_Descr* res = _dtype_promotion(args, nargs);
+        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
+PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
+    if (!nargs) {
+        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
+        return nullptr;
+    }
+    try {
+        CompNode cn = _get_device(args, nargs);
+        return py::cast(cn).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 bool is_scalar(PyObject* tensor) {
     if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
         auto var = py::handle(tensor).cast<PySymbolVar*>();
@@ -147,7 +364,6 @@ py::object _Const(
                 "dmap_callback");
         if (dmap.ptr() != Py_None) {
             device_obj = dmap(device);
-            py::print(device_obj);
         } else {
             device_obj = py::cast(CompNode::load(device.cast<std::string>()));
         }
@@ -1072,6 +1288,92 @@ py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
     return ret[0];
 }

+mgb::DType _get_dtype(py::handle tensor) {
+    if (auto tw = TensorWrapper::try_cast(tensor.ptr())) {
+        return tw->m_tensor->dtype();
+    } else {
+        auto var = tensor.cast<PySymbolVar*>();
+        return var->m_node->dtype();
+    }
+}
+
+py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
+    PyArray_Descr* descr;
+    if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
+        throw py::value_error(ssprintf(
+                "can not convert to numpy.dtype from %s",
+                dtype_hdl.ptr()->ob_type->tp_name));
+    }
+    PyArray_Descr* cur = npy::dtype_mgb2np_descr(_get_dtype(tensor)).get();
+    if (!dtype_equal(cur, descr)) {
+        std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
+        py::object Op = py::cast(op);
+        std::vector<PyObject*> p;
+        p.resize(2);
+        p[0] = Op.ptr();
+        p[1] = tensor.ptr();
+        py::tuple ret =
+                py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+        return ret[0];
+    } else {
+        return py::reinterpret_borrow<py::object>(tensor);
+    }
+}
+
+py::object _convert_single_value_cpp(
+        py::handle value, py::handle dtype, py::handle device) {
+    if (is_tensor_or_symbolvar(value)) {
+        if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
+            return _astype_cpp(value, dtype);
+        }
+    } else {
+        return _Const(value, dtype, device, py::none());
+    }
+    return py::reinterpret_borrow<py::object>(value);
+}
+
+py::object _convert_inputs_cpp(
+        PyObject* const* args, size_t nargs, py::object dtype, py::object device) {
+    ComputingGraph* graph = nullptr;
+    py::handle typeobj;
+    py::list lis;
+    for (size_t i = 0; i < nargs; ++i) {
+        py::handle h = py::handle(args[i]);
+        lis.append(h);
+        if (py::isinstance<PySymbolVar>(h)) {
+            auto var = h.cast<PySymbolVar*>();
+            auto g = var->m_node->owner_graph();
+            if (!graph) {
+                graph = g;
+                typeobj = h.get_type();
+            } else {
+                mgb_assert(graph == g);
+            }
+        }
+    }
+    if (graph) {
+        CompNode cn = device.cast<CompNode>();
+        for (size_t i = 0; i < nargs; ++i) {
+            OperatorNodeConfig config(cn);
+            auto hv = npy::np2tensor(
+                    lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
+            if (py::isinstance<PySymbolVar>(lis[i])) {
+                lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
+            }
+        }
+    }
+    auto convert = [&](py::object value) {
+        if (value.ptr() == Py_None) {
+            return value;
+        }
+        return _convert_single_value_cpp(value, dtype, device);
+    };
+    for (size_t i = 0; i < lis.size(); ++i) {
+        lis[i] = convert(lis[i]);
+    }
+    return py::reinterpret_steal<py::tuple>(PyList_AsTuple(lis.ptr()));
+}
+
 PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -1152,4 +1454,38 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }

+PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _astype_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
+PyObject* convert_single_value_cpp(
+        PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _convert_single_value_cpp(
+                       py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
+                .release()
+                .ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
+PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        py::object dtype = py::reinterpret_steal<py::object>(
+                dtype_promotion(self, args, nargs - 1));
+        py::object device;
+        if (args[nargs - 1] == Py_None) {
+            device = py::reinterpret_steal<py::object>(
+                    get_device(self, args, nargs - 1));
+        } else {
+            device = py::reinterpret_borrow<py::object>(args[nargs - 1]);
+        }
+        return _convert_inputs_cpp(args, nargs - 1, dtype, device).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 } // namespace mgb::imperative::python
@@ -2,6 +2,10 @@
 namespace mgb::imperative::python {

+PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs);
+
+PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs);
+
 PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs);

 PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
@@ -22,4 +26,10 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);

+PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
+PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
+PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
 } // namespace mgb::imperative::python
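A minimal pure-Python sketch (not part of the patch) of the promotion rule implemented by the `category_priority` / `max_priority` / `promote_types` / `_dtype_promotion` helpers that this diff moves into tensor_utils: inputs are split into tensor dtypes and scalar dtypes, each group is ranked by category (float > signed/unsigned int > bool), and scalar dtypes only decide the result when their category strictly outranks every tensor input. The `promote` helper name and the use of `numpy` below are illustrative assumptions for readers, not code from the patch.

    import numpy as np

    _PRIORITY = {"f": 3, "i": 2, "u": 2, "b": 1}  # float > int/uint > bool

    def promote(tensor_dtypes, scalar_dtypes):
        # Mirror of max_priority(): highest category seen in each group.
        def max_pri(ds):
            return max((_PRIORITY.get(d.kind, 0) for d in ds), default=0)

        t_pri, s_pri = max_pri(tensor_dtypes), max_pri(scalar_dtypes)
        if t_pri == 0 and s_pri == 0:
            raise ValueError("invalid input, no dtype available")
        # Scalars only win when their category strictly outranks the tensors',
        # matching the max_pri_scalars > max_pri_tensors branch in _dtype_promotion.
        pool, pri = (scalar_dtypes, s_pri) if s_pri > t_pri else (tensor_dtypes, t_pri)
        picked = [d for d in pool if _PRIORITY.get(d.kind, 0) == pri]
        res = picked[0]
        for d in picked[1:]:
            res = np.promote_types(d, res)  # same pairwise fold as promote_types()
        return res

    # Example: a float32 tensor combined with a Python int scalar stays float32,
    # because the tensor category (float, 3) outranks the scalar category (int, 2).
    assert promote([np.dtype("float32")], [np.dtype("int32")]) == np.dtype("float32")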