| @@ -17,7 +17,7 @@ from .. import _config | |||||
| from .._imperative_rt.common import CompNode | from .._imperative_rt.common import CompNode | ||||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion | from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion | ||||
| from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | ||||
| from .._imperative_rt.core2 import squeeze_cpp | |||||
| from .._imperative_rt.core2 import squeeze_cpp, transpose_cpp | |||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from . import amp | from . import amp | ||||
| from .indexing import getitem, setitem | from .indexing import getitem, setitem | ||||
| @@ -331,12 +331,6 @@ def _matmul( | |||||
| return result | return result | ||||
def _transpose(data, axes):
    """Apply a ``Dimshuffle`` op built from ``axes`` to ``data``.

    The op is expected to yield exactly one output, which is returned.
    """
    outputs = apply(builtin.Dimshuffle(axes), data)
    # Unpacking (rather than indexing) keeps the "exactly one output" check.
    (transposed,) = outputs
    return transposed
| def _broadcast(inp, shape): | def _broadcast(inp, shape): | ||||
| auto_infer = False | auto_infer = False | ||||
| if isinstance(shape, (list, tuple)): | if isinstance(shape, (list, tuple)): | ||||
| @@ -681,15 +675,7 @@ class ArrayMethodMixin(abc.ABC): | |||||
| def transpose(self, *args): | def transpose(self, *args): | ||||
| r"""See :func:`~.transpose`.""" | r"""See :func:`~.transpose`.""" | ||||
| if self.ndim == 0: | |||||
| assert ( | |||||
| len(args) == 0 | |||||
| ), "transpose for scalar does not accept additional args" | |||||
| ret = self.to(self.device) | |||||
| return ret | |||||
| if not args: | |||||
| args = range(self.ndim)[::-1] | |||||
| return _transpose(self, _expand_args(args)) | |||||
| return transpose_cpp(self, args) | |||||
| def flatten(self): | def flatten(self): | ||||
| r"""See :func:`~.flatten`.""" | r"""See :func:`~.flatten`.""" | ||||
| @@ -865,7 +865,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||||
| [[1 0] | [[1 0] | ||||
| [1 0]] | [1 0]] | ||||
| """ | """ | ||||
| return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern)) | |||||
| return inp.transpose(pattern) | |||||
| def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | ||||
| @@ -636,6 +636,7 @@ WRAP_FUNC_PY35(setitem_cpp); | |||||
| WRAP_FUNC_PY35(split_cpp); | WRAP_FUNC_PY35(split_cpp); | ||||
| WRAP_FUNC_PY35(expand_dims_cpp); | WRAP_FUNC_PY35(expand_dims_cpp); | ||||
| WRAP_FUNC_PY35(squeeze_cpp); | WRAP_FUNC_PY35(squeeze_cpp); | ||||
| WRAP_FUNC_PY35(transpose_cpp); | |||||
| #undef WRAP_FUNC_PY35 | #undef WRAP_FUNC_PY35 | ||||
| #define MGE_PY_INTERFACE(NAME, FUNC) \ | #define MGE_PY_INTERFACE(NAME, FUNC) \ | ||||
| { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | ||||
| @@ -771,6 +772,7 @@ void init_tensor(py::module m) { | |||||
| MGE_PY_INTERFACE(split_cpp, split_cpp), | MGE_PY_INTERFACE(split_cpp, split_cpp), | ||||
| MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp), | MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp), | ||||
| MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp), | MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp), | ||||
| MGE_PY_INTERFACE(transpose_cpp, transpose_cpp), | |||||
| {nullptr, nullptr, 0, nullptr}}; | {nullptr, nullptr, 0, nullptr}}; | ||||
| for (auto&& def : method_defs) { | for (auto&& def : method_defs) { | ||||
| if (def.ml_meth != nullptr) { | if (def.ml_meth != nullptr) { | ||||
| @@ -793,6 +793,57 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) { | |||||
| return ret[0]; | return ret[0]; | ||||
| } | } | ||||
| size_t fast_ndim(py::handle tensor) { | |||||
| if (auto p = TensorWrapper::try_cast(tensor.ptr())) { | |||||
| return p->m_tensor->shape()->ndim; | |||||
| } | |||||
| return getattr(tensor, "ndim").cast<size_t>(); | |||||
| } | |||||
| py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||||
| py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr()); | |||||
| if (fast_ndim(inp_hdl) == 0) { | |||||
| if (args_tup.size() != 0) { | |||||
| throw py::index_error( | |||||
| "transpose for scalar does not accept additional args"); | |||||
| } | |||||
| return getattr(inp_hdl, "to")(getattr(inp_hdl, "device")); | |||||
| } | |||||
| std::vector<int32_t> pattern; | |||||
| if (!args_tup.size()) { | |||||
| size_t ndim = getattr(inp_hdl, "ndim").cast<size_t>(); | |||||
| for (size_t i = 0; i < ndim; ++i) { | |||||
| pattern.push_back(ndim - i - 1); | |||||
| } | |||||
| } else { | |||||
| py::list lis; | |||||
| if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) || | |||||
| is_tensor_or_symbolvar(args_tup[0].ptr()))) { | |||||
| lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup[0].ptr())); | |||||
| } else { | |||||
| lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr())); | |||||
| } | |||||
| for (size_t i = 0; i < lis.size(); ++i) { | |||||
| if (PyLong_Check(lis[i].ptr())) { | |||||
| pattern.push_back(lis[i].cast<int32_t>()); | |||||
| } else { | |||||
| if (lis[i].cast<std::string>() == "x") { | |||||
| pattern.push_back(-1); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| std::shared_ptr<OpDef> op = Dimshuffle::make(pattern); | |||||
| std::vector<PyObject*> p; | |||||
| p.resize(2); | |||||
| py::object Op = py::cast(op); | |||||
| p[0] = Op.ptr(); | |||||
| p[1] = inp_hdl.ptr(); | |||||
| py::tuple ret = | |||||
| py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | |||||
| return ret[0]; | |||||
| } | |||||
| PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | ||||
| try { | try { | ||||
| return _make_shape_tuple(py::handle(args[0])).release().ptr(); | return _make_shape_tuple(py::handle(args[0])).release().ptr(); | ||||
| @@ -842,4 +893,11 @@ PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
| } | } | ||||
| PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| try { | |||||
| return _transpose_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr(); | |||||
| } | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
| } | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
| @@ -14,4 +14,6 @@ PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
| PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
| PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||