@@ -683,17 +683,21 @@ py::object _split_cpp(
     return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
 }
 
-py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
+std::vector<int32_t> list2vector(py::handle li) {
     std::vector<int32_t> axis;
-    if (is_py_sequence(axis_hdl.ptr())) {
-        py::list tmp_list =
-                py::reinterpret_steal<py::list>(PySequence_List(axis_hdl.ptr()));
+    if (is_py_sequence(li.ptr())) {
+        py::list tmp_list = py::reinterpret_steal<py::list>(PySequence_List(li.ptr()));
         for (size_t i = 0; i < tmp_list.size(); ++i) {
             axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
         }
     } else {
-        axis.push_back(getattr(axis_hdl, "__int__")().cast<int>());
+        axis.push_back(getattr(li, "__int__")().cast<int32_t>());
     }
+    return axis;
+}
+
+py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
+    std::vector<int32_t> axis = list2vector(axis_hdl);
     bool unknown_ndim = true;
     size_t ndim = axis.size();
     if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
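
The first hunk extracts the axis-normalization logic into a reusable `list2vector` helper (also tightening the scalar path from `cast<int>()` to `cast<int32_t>()`), so `_expand_dims_cpp` and the new `_squeeze_cpp` below share one code path. A minimal standalone analogue of the idea, assuming an embedded interpreter; `to_axes` is a hypothetical stand-in, not the patched function, and it casts directly where the patch coerces through `__int__`:

    #include <pybind11/embed.h>
    #include <cstdint>
    #include <cstdio>
    #include <vector>
    namespace py = pybind11;

    // Anything sequence-like becomes a vector of int32_t; a bare scalar
    // becomes a one-element vector.
    std::vector<int32_t> to_axes(py::handle li) {
        std::vector<int32_t> axis;
        if (PySequence_Check(li.ptr())) {
            py::list tmp = py::reinterpret_steal<py::list>(PySequence_List(li.ptr()));
            for (size_t i = 0; i < tmp.size(); ++i) {
                axis.push_back(tmp[i].cast<int32_t>());
            }
        } else {
            axis.push_back(li.cast<int32_t>());
        }
        return axis;
    }

    int main() {
        py::scoped_interpreter guard;
        auto a = to_axes(py::make_tuple(0, 2));  // sequence path -> {0, 2}
        auto b = to_axes(py::int_(-1));          // scalar path   -> {-1}
        std::printf("%zu %zu\n", a.size(), b.size());
    }

Both shapes of input funnel through the same helper, which is what lets `_squeeze_cpp` below reuse it for its optional `axis` argument.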
@@ -718,7 +722,7 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
                         "Does not support negative index when tensor's ndim is "
                         "unknown");
             }
-            axis[i] += ndim;
+            axis[i] += static_cast<int32_t>(ndim);
         }
     }
     if (!axis.size()) {
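
The second hunk is a one-line type-hygiene fix: `ndim` is a `size_t`, so `axis[i] += ndim` mixes signed and unsigned operands. Under the usual arithmetic conversions the signed index is turned unsigned before the addition, and the correct value is only recovered by the narrowing back into `int32_t`; the explicit `static_cast<int32_t>(ndim)` keeps the arithmetic signed. A small sketch of the pitfall, with hypothetical values and assuming a 64-bit `size_t`:

    #include <cstdint>
    #include <cstdio>

    int main() {
        std::size_t ndim = 4;
        int32_t axis = -1;
        // In `axis + ndim`, the signed -1 first becomes a huge unsigned
        // value; only the narrowing back into int32_t recovers 3.
        int32_t mixed = static_cast<int32_t>(axis + ndim);  // relies on wraparound
        int32_t clean = axis + static_cast<int32_t>(ndim);  // stays signed: 3
        std::printf("%d %d\n", mixed, clean);  // 3 3
    }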
@@ -736,6 +740,59 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
     return ret[0];
 }
 
+py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
+    std::vector<int32_t> axis;
+    size_t ndim;
+    if (axis_hdl.ptr() != Py_None) {
+        axis = list2vector(axis_hdl);
+    }
+    if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
+        auto&& shape = p->m_tensor->shape();
+        if (shape) {
+            ndim = shape->ndim;
+            if (axis_hdl.ptr() == Py_None) {
+                for (size_t i = 0; i < shape->ndim; ++i) {
+                    if (shape->shape[i] == 1) {
+                        axis.push_back(i);
+                    }
+                }
+            }
+        }
+    } else {
+        auto&& var = inp_hdl.cast<PySymbolVar*>();
+        auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
+        auto&& shape = mgr.infer_shape_fallible(var->m_node);
+        if (shape) {
+            ndim = shape->ndim;
+            if (axis_hdl.ptr() == Py_None) {
+                for (size_t i = 0; i < shape->ndim; ++i) {
+                    if (shape->shape[i] == 1) {
+                        axis.push_back(i);
+                    }
+                }
+            }
+        }
+    }
+    for (size_t i = 0; i < axis.size(); ++i) {
+        if (axis[i] < 0) {
+            axis[i] += static_cast<int32_t>(ndim);
+        }
+    }
+    std::sort(axis.begin(), axis.end());
+    for (size_t i = 0; i < axis.size(); ++i) {
+        axis[i] -= static_cast<int32_t>(i);
+    }
+    std::shared_ptr<OpDef> op = RemoveAxis::make(axis = axis);
+    std::vector<PyObject*> p;
+    p.resize(2);
+    py::object Op = py::cast(op);
+    p[0] = Op.ptr();
+    p[1] = inp_hdl.ptr();
+    py::tuple ret =
+            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+    return ret[0];
+}
+
 PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _make_shape_tuple(py::handle(args[0])).release().ptr();
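
The bulk of the patch is `_squeeze_cpp`: with `axis=None` it collects every dimension of extent 1 (only when static shape inference succeeds; note `ndim` is assigned only in that case), normalizes negative indices, and then compensates for the fact that `RemoveAxis` drops dimensions one at a time, each removal shifting the later indices left by one. Sorting ascending and subtracting the position index performs that compensation. A worked standalone sketch of the compensation step:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // Squeezing axes {0, 2} of a (1, 3, 1, 5) tensor: removing axis 0
        // first shifts the old axis 2 down to position 1 for the second
        // removal, so the adjusted list is {0, 1}.
        std::vector<int32_t> axis = {2, 0};
        std::sort(axis.begin(), axis.end());  // {0, 2}
        for (size_t i = 0; i < axis.size(); ++i) {
            axis[i] -= static_cast<int32_t>(i);  // {0, 1}
        }
        std::printf("%d %d\n", axis[0], axis[1]);  // 0 1
    }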
@@ -778,4 +835,11 @@ PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }
 
+PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _squeeze_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 } // namespace mgb::imperative::python
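
The last hunk adds the `squeeze_cpp` entry point with the same `PyObject* const* args, size_t nargs` shape as its siblings, which matches CPython's METH_FASTCALL calling convention. The method table itself is outside this diff; a plausible, purely hypothetical registration sketch would look like:

    // Hypothetical registration (the real table lives elsewhere in the
    // module): METH_FASTCALL matches the (self, args-array, nargs) signature.
    #include <Python.h>

    PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs);

    static PyMethodDef squeeze_def = {
            "squeeze_cpp",             // Python-visible name (assumed)
            (PyCFunction)squeeze_cpp,  // cast needed for the fastcall signature
            METH_FASTCALL,
            nullptr};

From Python, such an entry point would be invoked as `squeeze_cpp(tensor, axis)`, mirroring the `_squeeze_cpp(py::handle(args[0]), py::handle(args[1]))` call in the wrapper.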