|
|
|
@@ -16,7 +16,7 @@ |
|
|
|
|
|
|
|
#include <pybind11/numpy.h> |
|
|
|
#include <pybind11/operators.h> |
|
|
|
|
|
|
|
#include "./helper.h" |
|
|
|
namespace py = pybind11; |
|
|
|
|
|
|
|
namespace mgb::imperative::python { |
|
|
|
@@ -201,6 +201,24 @@ PyObject* TensorWrapper::detach() { |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
PyObject* TensorWrapper::_dev_tensor(){ |
|
|
|
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get()); |
|
|
|
return py::cast(dev_tensor).release().ptr(); |
|
|
|
} |
|
|
|
|
|
|
|
void TensorWrapper::_swap_out() { |
|
|
|
interpreter_for_py->swap_out(m_tensor->m_handle.get()); |
|
|
|
} |
|
|
|
|
|
|
|
void TensorWrapper::_swap_in() { |
|
|
|
interpreter_for_py->swap_in(m_tensor->m_handle.get()); |
|
|
|
} |
|
|
|
|
|
|
|
void TensorWrapper::_drop() { |
|
|
|
interpreter_for_py->drop(m_tensor->m_handle.get()); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
PyObject* TensorWrapper::isscalar() { |
|
|
|
if(m_tensor->m_flags & Tensor::Flags::SCALAR) { |
|
|
|
Py_RETURN_TRUE; |
|
|
|
@@ -240,6 +258,10 @@ void init_tensor(py::module m) { |
|
|
|
.def<&TensorWrapper::isscalar>("isscalar") |
|
|
|
.def<&TensorWrapper::setscalar>("setscalar") |
|
|
|
.def<&TensorWrapper::detach>("detach") |
|
|
|
.def<&TensorWrapper::_dev_tensor>("_dev_tensor") |
|
|
|
.def<&TensorWrapper::_swap_out>("_swap_out") |
|
|
|
.def<&TensorWrapper::_swap_in>("_swap_in") |
|
|
|
.def<&TensorWrapper::_drop>("_drop") |
|
|
|
.finalize(); |
|
|
|
if (!tensor_type) throw py::error_already_set(); |
|
|
|
py::setattr(m, "Tensor", tensor_type); |
|
|
|
@@ -253,6 +275,21 @@ void init_tensor(py::module m) { |
|
|
|
if (!apply_func) throw py::error_already_set(); |
|
|
|
py::setattr(m, "apply", apply_func); |
|
|
|
|
|
|
|
m.def("_set_swap_flag", |
|
|
|
[](bool flag) { interpreter_for_py->set_swap_flag(flag); }); |
|
|
|
m.def("_set_drop_flag", |
|
|
|
[](bool flag) { interpreter_for_py->set_drop_flag(flag); }); |
|
|
|
m.def("config_async_level", |
|
|
|
[](int level) { interpreter_for_py->config_async_level(level); }); |
|
|
|
m.def("get_async_level", |
|
|
|
[]() { return interpreter_for_py->get_async_level(); }); |
|
|
|
m.def("sync", |
|
|
|
[]() { |
|
|
|
interpreter_for_py->sync(); |
|
|
|
py_task_q.wait_all_task_finish(); |
|
|
|
}, |
|
|
|
py::call_guard<py::gil_scoped_release>()); |
|
|
|
|
|
|
|
py::handle grad_key_type = GradKeyWrapper::wrap_t::type() |
|
|
|
.def<&GradKeyWrapper::attach>("attach") |
|
|
|
.finalize(); |
|
|
|
|