| @@ -17,6 +17,7 @@ from .. import _config | |||||
| from .._imperative_rt.common import CompNode | from .._imperative_rt.common import CompNode | ||||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion | from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion | ||||
| from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | ||||
| from .._imperative_rt.core2 import squeeze_cpp | |||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from . import amp | from . import amp | ||||
| from .indexing import getitem, setitem | from .indexing import getitem, setitem | ||||
| @@ -448,26 +449,6 @@ def _logical_binary_elwise(mode, rev=False): | |||||
| return f | return f | ||||
| def _remove_axis(inp: Tensor, axis) -> Tensor: | |||||
| def get_axes(): | |||||
| if axis is None: | |||||
| shp = inp.shape | |||||
| return [i for i, s in enumerate(shp) if s == 1] | |||||
| try: | |||||
| return [int(axis)] | |||||
| except (TypeError, ValueError): | |||||
| pass | |||||
| return list(map(int, axis)) | |||||
| axis = get_axes() | |||||
| axis = _normalize_axis(inp.ndim, axis) | |||||
| axis = [a - i for i, a in enumerate(axis)] | |||||
| op = builtin.RemoveAxis(axis=axis) | |||||
| (result,) = apply(op, inp) | |||||
| return result | |||||
| def _reduce(mode): | def _reduce(mode): | ||||
| def f(self, axis=None, keepdims: bool = False): | def f(self, axis=None, keepdims: bool = False): | ||||
| data = self | data = self | ||||
| @@ -480,7 +461,7 @@ def _reduce(mode): | |||||
| op = builtin.Reduce(mode=mode, axis=ai) | op = builtin.Reduce(mode=mode, axis=ai) | ||||
| (data,) = apply(op, data) | (data,) = apply(op, data) | ||||
| if not keepdims: | if not keepdims: | ||||
| data = _remove_axis(data, ai) | |||||
| data = squeeze_cpp(data, ai) | |||||
| result = data | result = data | ||||
| else: | else: | ||||
| # builtin.Reduce already accept negtive axis | # builtin.Reduce already accept negtive axis | ||||
| @@ -488,7 +469,7 @@ def _reduce(mode): | |||||
| (result,) = apply(op, data) | (result,) = apply(op, data) | ||||
| if not keepdims: | if not keepdims: | ||||
| result = _remove_axis(result, axis) | |||||
| result = squeeze_cpp(result, axis) | |||||
| return result | return result | ||||
| return f | return f | ||||
| @@ -18,12 +18,13 @@ from ..core._imperative_rt.core2 import ( | |||||
| dtype_promotion, | dtype_promotion, | ||||
| expand_dims_cpp, | expand_dims_cpp, | ||||
| split_cpp, | split_cpp, | ||||
| squeeze_cpp, | |||||
| ) | ) | ||||
| from ..core._wrap import as_device | from ..core._wrap import as_device | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.builtin import Copy, Identity | from ..core.ops.builtin import Copy, Identity | ||||
| from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
| from ..core.tensor.array_method import _broadcast, _remove_axis | |||||
| from ..core.tensor.array_method import _broadcast | |||||
| from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn | from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn | ||||
| from ..device import get_default_device | from ..device import get_default_device | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| @@ -996,7 +997,7 @@ def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Te | |||||
| (1, 1, 2) | (1, 1, 2) | ||||
| """ | """ | ||||
| return _remove_axis(inp, axis) | |||||
| return squeeze_cpp(inp, axis) | |||||
| def linspace( | def linspace( | ||||
| @@ -635,6 +635,7 @@ WRAP_FUNC_PY35(getitem_cpp); | |||||
| WRAP_FUNC_PY35(setitem_cpp); | WRAP_FUNC_PY35(setitem_cpp); | ||||
| WRAP_FUNC_PY35(split_cpp); | WRAP_FUNC_PY35(split_cpp); | ||||
| WRAP_FUNC_PY35(expand_dims_cpp); | WRAP_FUNC_PY35(expand_dims_cpp); | ||||
| WRAP_FUNC_PY35(squeeze_cpp); | |||||
| #undef WRAP_FUNC_PY35 | #undef WRAP_FUNC_PY35 | ||||
| #define MGE_PY_INTERFACE(NAME, FUNC) \ | #define MGE_PY_INTERFACE(NAME, FUNC) \ | ||||
| { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | ||||
| @@ -769,6 +770,7 @@ void init_tensor(py::module m) { | |||||
| MGE_PY_INTERFACE(setitem_cpp, setitem_cpp), | MGE_PY_INTERFACE(setitem_cpp, setitem_cpp), | ||||
| MGE_PY_INTERFACE(split_cpp, split_cpp), | MGE_PY_INTERFACE(split_cpp, split_cpp), | ||||
| MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp), | MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp), | ||||
| MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp), | |||||
| {nullptr, nullptr, 0, nullptr}}; | {nullptr, nullptr, 0, nullptr}}; | ||||
| for (auto&& def : method_defs) { | for (auto&& def : method_defs) { | ||||
| if (def.ml_meth != nullptr) { | if (def.ml_meth != nullptr) { | ||||
| @@ -683,17 +683,21 @@ py::object _split_cpp( | |||||
| return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | ||||
| } | } | ||||
| py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) { | |||||
| std::vector<int32_t> list2vector(py::handle li) { | |||||
| std::vector<int32_t> axis; | std::vector<int32_t> axis; | ||||
| if (is_py_sequence(axis_hdl.ptr())) { | |||||
| py::list tmp_list = | |||||
| py::reinterpret_steal<py::list>(PySequence_List(axis_hdl.ptr())); | |||||
| if (is_py_sequence(li.ptr())) { | |||||
| py::list tmp_list = py::reinterpret_steal<py::list>(PySequence_List(li.ptr())); | |||||
| for (size_t i = 0; i < tmp_list.size(); ++i) { | for (size_t i = 0; i < tmp_list.size(); ++i) { | ||||
| axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>()); | axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>()); | ||||
| } | } | ||||
| } else { | } else { | ||||
| axis.push_back(getattr(axis_hdl, "__int__")().cast<int>()); | |||||
| axis.push_back(getattr(li, "__int__")().cast<int32_t>()); | |||||
| } | } | ||||
| return axis; | |||||
| } | |||||
| py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) { | |||||
| std::vector<int32_t> axis = list2vector(axis_hdl); | |||||
| bool unknown_ndim = true; | bool unknown_ndim = true; | ||||
| size_t ndim = axis.size(); | size_t ndim = axis.size(); | ||||
| if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) { | if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) { | ||||
| @@ -718,7 +722,7 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) { | |||||
| "Does not support negative index when tensor's ndim is " | "Does not support negative index when tensor's ndim is " | ||||
| "unknown"); | "unknown"); | ||||
| } | } | ||||
| axis[i] += ndim; | |||||
| axis[i] += static_cast<int32_t>(ndim); | |||||
| } | } | ||||
| } | } | ||||
| if (!axis.size()) { | if (!axis.size()) { | ||||
| @@ -736,6 +740,59 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) { | |||||
| return ret[0]; | return ret[0]; | ||||
| } | } | ||||
| py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) { | |||||
| std::vector<int32_t> axis; | |||||
| size_t ndim; | |||||
| if (axis_hdl.ptr() != Py_None) { | |||||
| axis = list2vector(axis_hdl); | |||||
| } | |||||
| if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) { | |||||
| auto&& shape = p->m_tensor->shape(); | |||||
| if (shape) { | |||||
| ndim = shape->ndim; | |||||
| if (axis_hdl.ptr() == Py_None) { | |||||
| for (size_t i = 0; i < shape->ndim; ++i) { | |||||
| if (shape->shape[i] == 1) { | |||||
| axis.push_back(i); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| auto&& var = inp_hdl.cast<PySymbolVar*>(); | |||||
| auto&& mgr = var->m_node->owner_graph()->static_infer_manager(); | |||||
| auto&& shape = mgr.infer_shape_fallible(var->m_node); | |||||
| if (shape) { | |||||
| ndim = shape->ndim; | |||||
| if (axis_hdl.ptr() == Py_None) { | |||||
| for (size_t i = 0; i < shape->ndim; ++i) { | |||||
| if (shape->shape[i] == 1) { | |||||
| axis.push_back(i); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < axis.size(); ++i) { | |||||
| if (axis[i] < 0) { | |||||
| axis[i] += static_cast<int32_t>(ndim); | |||||
| } | |||||
| } | |||||
| std::sort(axis.begin(), axis.end()); | |||||
| for (size_t i = 0; i < axis.size(); ++i) { | |||||
| axis[i] -= static_cast<int32_t>(i); | |||||
| } | |||||
| std::shared_ptr<OpDef> op = RemoveAxis::make(axis = axis); | |||||
| std::vector<PyObject*> p; | |||||
| p.resize(2); | |||||
| py::object Op = py::cast(op); | |||||
| p[0] = Op.ptr(); | |||||
| p[1] = inp_hdl.ptr(); | |||||
| py::tuple ret = | |||||
| py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | |||||
| return ret[0]; | |||||
| } | |||||
| PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | ||||
| try { | try { | ||||
| return _make_shape_tuple(py::handle(args[0])).release().ptr(); | return _make_shape_tuple(py::handle(args[0])).release().ptr(); | ||||
| @@ -778,4 +835,11 @@ PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
| } | } | ||||
| PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| try { | |||||
| return _squeeze_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr(); | |||||
| } | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
| } | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
| @@ -12,4 +12,6 @@ PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
| PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
| PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||