@@ -19,6 +19,7 @@ from .._imperative_rt.core2 import (
    SymbolVar,
    Tensor,
    apply,
    astype_cpp,
    broadcast_cpp,
    dtype_promotion,
)
@@ -27,14 +28,7 @@ from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
from ..ops import builtin
from . import amp
from .indexing import getitem, setitem
from .utils import (
    _normalize_axis,
    astensor1d,
    astype,
    cast_tensors,
    make_shape_tuple,
    subgraph,
)
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph

_ElwMod = builtin.Elemwise.Mode
@@ -605,7 +599,7 @@ class ArrayMethodMixin(abc.ABC):
        r"""Returns a :class:`Tensor` with the same data and number of elements
        with the specified :attr:`~.Tensor.dtype`.
        """
        return astype(self, dtype)
        return astype_cpp(self, dtype)

    def reshape(self, *args):
        r"""See :func:`~.reshape`."""
@@ -20,6 +20,9 @@ from .._imperative_rt.core2 import (
    _get_convert_inputs,
    _set_convert_inputs,
    apply,
    astype_cpp,
    convert_inputs_cpp,
    convert_single_value_cpp,
    dtype_promotion,
    get_device,
    make_shape_tuple,
@@ -55,53 +58,14 @@ def concatenate(inputs, axis=0, *, device=None):
    return result


def astype(x, dtype):
    dtype = np.dtype(dtype)
    if not is_dtype_equal(x.dtype, dtype):
        (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
    return x


def convert_single_value(v, *, dtype=None, device=None):
    if isinstance(v, (Tensor, SymbolVar)):
        if not is_quantize(v.dtype):
            v = astype(v, dtype)
    else:
        v = Const(v, dtype, device, None)
    return v
    return convert_single_value_cpp(v, dtype, device)


def convert_inputs(*args, device=None):
    if not _get_convert_inputs():
        return args
    dtype = dtype_promotion(args)
    if device is None:
        device = get_device(args)
    device = as_device(device)
    graph = None
    sym_type = None
    for a in args:
        if isinstance(a, SymbolVar):
            if graph is None:
                graph = a.var.graph
                sym_type = type(a)
            else:
                assert graph == a.var.graph
    args = list(args)
    if graph is not None:
        for i in range(len(args)):
            if not isinstance(args[i], SymbolVar):
                rst = make_const(graph, np.array(args[i]), device.to_c(), dtype)
                args[i] = sym_type(rst)

    def convert(value):
        if value is None:
            return value
        return convert_single_value(value, dtype=dtype, device=device.to_c())

    return tuple(map(convert, args))
    return convert_inputs_cpp(*args, device)


def cast_tensors(*args, promote=False):
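A rough usage sketch (not part of the patch) of the contract the thin wrappers keep after delegating to the C++ implementations; it assumes the convert-inputs switch is enabled, which is the default.

```python
# Rough sketch: convert_inputs still returns all inputs on one promoted dtype
# and one device; only the heavy lifting now happens in convert_inputs_cpp.
import megengine as mge
from megengine.core.tensor.utils import convert_inputs

a = mge.tensor([1, 2], dtype="int32")
b, c = convert_inputs(a, 3.0)           # the float scalar drives promotion to float32
assert b.dtype == c.dtype
assert str(b.device) == str(c.device)   # both live on the same (inferred) device
```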
@@ -146,7 +110,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
        pass
    except ValueError:
        if dtype is not None and dtype != x.dtype:
            x = astype(x, dtype)
            x = astype_cpp(x, dtype)
        if device is not None:
            cn = as_device(device).to_c()
            (x,) = apply(builtin.Copy(comp_node=cn), x)
@@ -164,7 +128,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
    if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
        x = concatenate(x, device=device) if len(x) > 1 else x[0]
        if dtype is not None:
            x = astype(x, dtype)
            x = astype_cpp(x, dtype)
        return x
    x = Const(x, dtype, device, reference)
    return x
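For reference, a small sketch (hypothetical values, not part of the patch) of the astensor1d path touched here; with a plain tuple it still falls through to Const, while dtype casts now go through astype_cpp.

```python
# Small sketch of astensor1d with a plain shape-like tuple (falls through to Const).
import megengine as mge
from megengine.core.tensor.utils import astensor1d

ref = mge.tensor([0.0])
shape = astensor1d((2, 3), ref, dtype="int32", device=ref.device)
print(shape.numpy())  # [2 3]
```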
@@ -30,7 +30,6 @@ from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import (
    astensor1d,
    astype,
    cast_tensors,
    convert_single_value,
    make_shape_tuple,
@@ -170,6 +170,12 @@ struct _wrap {
} // anonymous namespace

namespace imperative::python {
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) {
    return _is_dtype_equal(dt1, dt2);
}
} // namespace imperative::python

#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUN) \
    { #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr }
@@ -26,6 +26,11 @@
    cb(BFloat16, npy_num_bfloat16())

namespace mgb {
namespace imperative::python {
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2);
} // namespace imperative::python

//! numpy type num for intb1/2/4 type
#define DEFINE_NPY_INTBX(n) int npy_num_intb##n();
FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX)
@@ -400,223 +400,6 @@ struct TensorWeakRef {
    int _use_cnt() { return wptr.use_count(); }
};

/* ============== convert inputs ============== */
// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
    switch (c) {
        case 'f':
            return 3; // floating-point
        case 'i':
            return 2; // signed integer
        case 'u':
            return 2; // unsigned integer
        case 'b':
            return 1; // boolean
        default:
            return 0;
    }
}

// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    if (types.size() == 0) {
        return 0;
    } else {
        uint8_t max_p = 0;
        for (auto&& desc : types) {
            max_p = std::max(max_p, category_priority(desc->kind));
        }
        return max_p;
    }
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc : types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}

PyArray_Descr* scalar2dtype(PyObject* arg) {
    // Return value: New reference
    if (PyBool_Check(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
        return descr;
    }
    if (PyLong_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_INT32);
        return descr;
    }
    if (PyFloat_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
        return descr;
    }
    return nullptr;
}

PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None)
            continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }
            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
                auto var = py::handle(handle).cast<PySymbolVar*>();
                mgb::DType type = var->m_node->dtype();
                auto&& descr = npy::dtype_mgb2np_descr(type);
                Py_INCREF(descr.get());
                tensors.emplace_back(descr.get());
                continue;
            }
            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }
    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);
    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype avaliable");
    }
    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }
    for (auto* p : tensors) {
        Py_DECREF(p);
    }
    for (auto* p : scalars) {
        Py_DECREF(p);
    }
    Py_XDECREF(tuple);
    return res;
}

CompNode _get_device(PyObject* const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
        if (tw || is_symvar) {
            if (!valid) {
                cn = tw ? tw->m_tensor->comp_node()
                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
                valid = true;
            } else {
                CompNode cn1 = tw ? tw->m_tensor->comp_node()
                                  : py::handle(handle)
                                            .cast<PySymbolVar*>()
                                            ->m_node->comp_node();
                if (cn1 != cn) {
                    throw py::value_error(ssprintf(
                            "ambiguous device: %s (from %s) vs %s (from %s)",
                            cn.to_string().c_str(), cn.to_string_logical().c_str(),
                            cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
                }
            }
        }
    }
    if (!valid) {
        return CompNode::load(get_default_device());
    }
    Py_XDECREF(tuple);
    return cn;
}

// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
@@ -640,6 +423,9 @@ WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -779,6 +565,9 @@ void init_tensor(py::module m) {
            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
            MGE_PY_INTERFACE(Const, Const),
            MGE_PY_INTERFACE(astype_cpp, astype_cpp),
            MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
            MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
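A hypothetical smoke test (not part of the patch) for the three entry points registered above; it assumes the module was built with these bindings exported through `megengine.core._imperative_rt.core2`, as the Python-side imports in this patch indicate.

```python
# Hypothetical smoke test for the newly registered fastcall entry points.
import megengine as mge
from megengine.core._imperative_rt.core2 import (
    astype_cpp,
    convert_inputs_cpp,
    convert_single_value_cpp,
)

x = mge.tensor([1, 2, 3], dtype="int32")
y = astype_cpp(x, "float32")  # same effect as x.astype("float32")
print(y.dtype)                # float32
```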
@@ -52,6 +52,223 @@ namespace views = ranges::views;
namespace mgb::imperative::python {

/* ============== convert inputs ============== */
// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
    switch (c) {
        case 'f':
            return 3; // floating-point
        case 'i':
            return 2; // signed integer
        case 'u':
            return 2; // unsigned integer
        case 'b':
            return 1; // boolean
        default:
            return 0;
    }
}

// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    if (types.size() == 0) {
        return 0;
    } else {
        uint8_t max_p = 0;
        for (auto&& desc : types) {
            max_p = std::max(max_p, category_priority(desc->kind));
        }
        return max_p;
    }
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc : types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}

PyArray_Descr* scalar2dtype(PyObject* arg) {
    // Return value: New reference
    if (PyBool_Check(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
        return descr;
    }
    if (PyLong_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_INT32);
        return descr;
    }
    if (PyFloat_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
        return descr;
    }
    return nullptr;
}

PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None)
            continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }
            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
                auto var = py::handle(handle).cast<PySymbolVar*>();
                mgb::DType type = var->m_node->dtype();
                auto&& descr = npy::dtype_mgb2np_descr(type);
                Py_INCREF(descr.get());
                tensors.emplace_back(descr.get());
                continue;
            }
            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }
    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);
    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype avaliable");
    }
    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }
    for (auto* p : tensors) {
        Py_DECREF(p);
    }
    for (auto* p : scalars) {
        Py_DECREF(p);
    }
    Py_XDECREF(tuple);
    return res;
}

CompNode _get_device(PyObject* const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
        if (tw || is_symvar) {
            if (!valid) {
                cn = tw ? tw->m_tensor->comp_node()
                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
                valid = true;
            } else {
                CompNode cn1 = tw ? tw->m_tensor->comp_node()
                                  : py::handle(handle)
                                            .cast<PySymbolVar*>()
                                            ->m_node->comp_node();
                if (cn1 != cn) {
                    throw py::value_error(ssprintf(
                            "ambiguous device: %s (from %s) vs %s (from %s)",
                            cn.to_string().c_str(), cn.to_string_logical().c_str(),
                            cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
                }
            }
        }
    }
    if (!valid) {
        return CompNode::load(get_default_device());
    }
    Py_XDECREF(tuple);
    return cn;
}

// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
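To make the promotion rule above concrete, a short sketch (hypothetical values, not part of the patch) of the two helpers as seen from Python: a floating-point scalar outranks integer tensors, and the device is taken from the tensor inputs.

```python
# Sketch of the promotion/device helpers exposed by these wrappers.
import megengine as mge
from megengine.core._imperative_rt.core2 import dtype_promotion, get_device

a = mge.tensor([1, 2], dtype="int32")
print(dtype_promotion(a, 2.5))  # float scalar (category 3) wins over int tensor -> float32
print(get_device(a, 1.0))       # device taken from the only tensor argument
```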
bool is_scalar(PyObject* tensor) {
    if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
        auto var = py::handle(tensor).cast<PySymbolVar*>();
@@ -147,7 +364,6 @@ py::object _Const(
                "dmap_callback");
        if (dmap.ptr() != Py_None) {
            device_obj = dmap(device);
            py::print(device_obj);
        } else {
            device_obj = py::cast(CompNode::load(device.cast<std::string>()));
        }
@@ -1072,6 +1288,92 @@ py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
    return ret[0];
}

mgb::DType _get_dtype(py::handle tensor) {
    if (auto tw = TensorWrapper::try_cast(tensor.ptr())) {
        return tw->m_tensor->dtype();
    } else {
        auto var = tensor.cast<PySymbolVar*>();
        return var->m_node->dtype();
    }
}

py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
    PyArray_Descr* descr;
    if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
        throw py::value_error(ssprintf(
                "can not convert to numpy.dtype from %s",
                dtype_hdl.ptr()->ob_type->tp_name));
    }
    PyArray_Descr* cur = npy::dtype_mgb2np_descr(_get_dtype(tensor)).get();
    if (!dtype_equal(cur, descr)) {
        std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
        py::object Op = py::cast(op);
        std::vector<PyObject*> p;
        p.resize(2);
        p[0] = Op.ptr();
        p[1] = tensor.ptr();
        py::tuple ret =
                py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
        return ret[0];
    } else {
        return py::reinterpret_borrow<py::object>(tensor);
    }
}

py::object _convert_single_value_cpp(
        py::handle value, py::handle dtype, py::handle device) {
    if (is_tensor_or_symbolvar(value)) {
        if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
            return _astype_cpp(value, dtype);
        }
    } else {
        return _Const(value, dtype, device, py::none());
    }
    return py::reinterpret_borrow<py::object>(value);
}

py::object _convert_inputs_cpp(
        PyObject* const* args, size_t nargs, py::object dtype, py::object device) {
    ComputingGraph* graph = nullptr;
    py::handle typeobj;
    py::list lis;
    for (size_t i = 0; i < nargs; ++i) {
        py::handle h = py::handle(args[i]);
        lis.append(h);
        if (py::isinstance<PySymbolVar>(h)) {
            auto var = h.cast<PySymbolVar*>();
            auto g = var->m_node->owner_graph();
            if (!graph) {
                graph = g;
                typeobj = h.get_type();
            } else {
                mgb_assert(graph == g);
            }
        }
    }
    if (graph) {
        CompNode cn = device.cast<CompNode>();
        for (size_t i = 0; i < nargs; ++i) {
            OperatorNodeConfig config(cn);
            auto hv = npy::np2tensor(
                    lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
            if (py::isinstance<PySymbolVar>(lis[i])) {
                lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
            }
        }
    }
    auto convert = [&](py::object value) {
        if (value.ptr() == Py_None) {
            return value;
        }
        return _convert_single_value_cpp(value, dtype, device);
    };
    for (size_t i = 0; i < lis.size(); ++i) {
        lis[i] = convert(lis[i]);
    }
    return py::reinterpret_steal<py::tuple>(PyList_AsTuple(lis.ptr()));
}

PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -1152,4 +1454,38 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _astype_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* convert_single_value_cpp(
        PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _convert_single_value_cpp(
                       py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        py::object dtype = py::reinterpret_steal<py::object>(
                dtype_promotion(self, args, nargs - 1));
        py::object device;
        if (args[nargs - 1] == Py_None) {
            device = py::reinterpret_steal<py::object>(
                    get_device(self, args, nargs - 1));
        } else {
            device = py::reinterpret_borrow<py::object>(args[nargs - 1]);
        }
        return _convert_inputs_cpp(args, nargs - 1, dtype, device).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

} // namespace mgb::imperative::python
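A sketch (not part of the patch) of the calling convention implied by `convert_inputs_cpp` above: the trailing positional argument is the device, or None to infer it from the tensor inputs, matching the Python-side call `convert_inputs_cpp(*args, device)`.

```python
# Sketch of the convert_inputs_cpp calling convention; values are illustrative.
import megengine as mge
from megengine.core._imperative_rt.core2 import convert_inputs_cpp

a = mge.tensor([1, 2], dtype="int32")
outs = convert_inputs_cpp(a, 0.5, None)   # None -> device inferred from tensor inputs
print([o.dtype for o in outs])            # both promoted to float32
```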
@@ -2,6 +2,10 @@
namespace mgb::imperative::python {

PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
@@ -22,4 +26,10 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs);

} // namespace mgb::imperative::python