@@ -683,6 +683,59 @@ py::object _split_cpp(
    return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
}

py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
    std::vector<int32_t> axis;
    if (is_py_sequence(axis_hdl.ptr())) {
        py::list tmp_list =
                py::reinterpret_steal<py::list>(PySequence_List(axis_hdl.ptr()));
        for (size_t i = 0; i < tmp_list.size(); ++i) {
            axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
        }
    } else {
        axis.push_back(getattr(axis_hdl, "__int__")().cast<int32_t>());
    }
    bool unknown_ndim = true;
    size_t ndim = axis.size();
    if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
        auto&& shape = p->m_tensor->shape();
        if (shape) {
            unknown_ndim = false;
            ndim += shape->ndim;
        }
    } else {
        auto&& var = inp_hdl.cast<PySymbolVar*>();
        auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
        auto&& shape = mgr.infer_shape_fallible(var->m_node);
        if (shape) {
            unknown_ndim = false;
            ndim += shape->ndim;
        }
    }
    for (size_t i = 0; i < axis.size(); ++i) {
        if (axis[i] < 0) {
            if (unknown_ndim) {
                throw py::index_error(
                        "Does not support negative index when tensor's ndim is "
                        "unknown");
            }
            axis[i] += ndim;
        }
    }
    if (!axis.size()) {
        throw py::index_error("axis could not be empty");
    }
    std::sort(axis.begin(), axis.end());
    std::shared_ptr<OpDef> op = AddAxis::make(axis = axis);
    std::vector<PyObject*> p;
    p.resize(2);
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = inp_hdl.ptr();
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    return ret[0];
}

PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -716,4 +769,13 @@ PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _expand_dims_cpp(py::handle(args[0]), py::handle(args[1]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

} // namespace mgb::imperative::python
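Not part of the diff above: the new expand_dims_cpp entry point still has to be
exposed to Python through the module's method table. The sketch below only
illustrates the METH_FASTCALL calling convention it follows; the table name
tensor_utils_methods and the docstring are hypothetical, and the real
registration code lives elsewhere in this module.

// Hypothetical sketch (assumes <Python.h> and the expand_dims_cpp declaration
// are in scope). expand_dims_cpp takes (self, args-array, nargs), which is the
// METH_FASTCALL convention, so a registration entry would look roughly like:
static PyMethodDef tensor_utils_methods[] = {
        {"expand_dims_cpp", (PyCFunction)expand_dims_cpp, METH_FASTCALL,
         "C++ fast path for expand_dims (hypothetical docstring)"},
        {nullptr, nullptr, 0, nullptr},  // sentinel entry
};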