| @@ -72,7 +72,6 @@ if sys.platform == "win32": | |||||
| kernel32.SetErrorMode(old_error_mode) | kernel32.SetErrorMode(old_error_mode) | ||||
| from .core._imperative_rt.core2 import full_sync as _full_sync | from .core._imperative_rt.core2 import full_sync as _full_sync | ||||
| from .core._imperative_rt.core2 import release_trace_apply_func | |||||
| from .core._imperative_rt.core2 import sync as _sync | from .core._imperative_rt.core2 import sync as _sync | ||||
| from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func | from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func | ||||
| from .device import * | from .device import * | ||||
| @@ -92,9 +91,7 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() | |||||
| _persistent_cache_impl_ins.reg() | _persistent_cache_impl_ins.reg() | ||||
| atexit.register(_full_sync) | atexit.register(_full_sync) | ||||
| atexit.register(release_trace_apply_func) | |||||
| del release_trace_apply_func | |||||
| del _set_fork_exec_path_for_timed_func | del _set_fork_exec_path_for_timed_func | ||||
| del _persistent_cache_impl_ins | del _persistent_cache_impl_ins | ||||
| @@ -34,22 +34,15 @@ namespace mgb::imperative::python { | |||||
| std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | ||||
| py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing, | |||||
| cpp_apply_compiled_mode, cpp_apply_const_compiled_mode; | |||||
| PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing, | |||||
| *cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode; | |||||
| py::object cpp_apply_backward_varnode; | |||||
| PyObject *cpp_apply_backward_varnode; | |||||
| void release_trace_apply_func(){ | |||||
| cpp_apply_with_tracing.release(); | |||||
| cpp_apply_const_with_tracing.release(); | |||||
| cpp_apply_compiled_mode.release(); | |||||
| cpp_apply_const_compiled_mode.release(); | |||||
| cpp_apply_backward_varnode.release(); | |||||
| } | |||||
| #define REGISTE_APPLY_FUNC(mode) \ | #define REGISTE_APPLY_FUNC(mode) \ | ||||
| void set_##mode(py::object pyf) { \ | void set_##mode(py::object pyf) { \ | ||||
| mode = pybind11::reinterpret_steal<py::object>(pyf); \ | |||||
| mode = pyf.ptr(); \ | |||||
| } | } | ||||
| REGISTE_APPLY_FUNC(cpp_apply_with_tracing) | REGISTE_APPLY_FUNC(cpp_apply_with_tracing) | ||||
| @@ -242,14 +235,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
| // const op | // const op | ||||
| if (is_const && is_tracing) { | if (is_const && is_tracing) { | ||||
| py::object pyf; | |||||
| PyObject *pyf; | |||||
| if (is_compiled) { | if (is_compiled) { | ||||
| pyf = cpp_apply_const_compiled_mode; | pyf = cpp_apply_const_compiled_mode; | ||||
| } else { | } else { | ||||
| pyf = cpp_apply_const_with_tracing; | pyf = cpp_apply_const_with_tracing; | ||||
| } | } | ||||
| auto ret = pyf(*tup); | |||||
| auto ret = py::reinterpret_steal<py::object>( | |||||
| PyObject_Call(pyf, tup.ptr(), nullptr)); | |||||
| auto py_ret = py::reinterpret_borrow<py::list>(ret); | auto py_ret = py::reinterpret_borrow<py::list>(ret); | ||||
| if (auto* t = try_cast(py_ret[0].ptr())) { | if (auto* t = try_cast(py_ret[0].ptr())) { | ||||
| m_tensor = t->m_tensor; | m_tensor = t->m_tensor; | ||||
| @@ -744,8 +738,6 @@ void init_tensor(py::module m) { | |||||
| }, | }, | ||||
| py::call_guard<py::gil_scoped_release>()); | py::call_guard<py::gil_scoped_release>()); | ||||
| m.def("release_trace_apply_func", &release_trace_apply_func); | |||||
| py::handle grad_key_type = GradKeyWrapper::wrap_t::type() | py::handle grad_key_type = GradKeyWrapper::wrap_t::type() | ||||
| .def<&GradKeyWrapper::attach>("attach") | .def<&GradKeyWrapper::attach>("attach") | ||||
| .def<&GradKeyWrapper::is_attached_to>("is_attached_to") | .def<&GradKeyWrapper::is_attached_to>("is_attached_to") | ||||
| @@ -253,8 +253,8 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors) | |||||
| void init_tensor(pybind11::module); | void init_tensor(pybind11::module); | ||||
| extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode; | |||||
| extern pybind11::object cpp_apply_backward_varnode; | |||||
| extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode; | |||||
| extern PyObject *cpp_apply_backward_varnode; | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "./trace.h" | #include "./trace.h" | ||||
| @@ -23,12 +24,13 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||||
| if (ctx.backward) { | if (ctx.backward) { | ||||
| // reach here when symbolic=True or compiled=True | // reach here when symbolic=True or compiled=True | ||||
| // call megbrain_graph.py apply(BackwardGraph, *args) | // call megbrain_graph.py apply(BackwardGraph, *args) | ||||
| auto args = py::tuple(ctx.nargs); | |||||
| auto args = py::tuple(ctx.nargs + 1); | |||||
| args[0] = py::cast(ctx.op); | |||||
| for (size_t i = 0; i < ctx.nargs; i++) { | for (size_t i = 0; i < ctx.nargs; i++) { | ||||
| args[i] = py::cast(ctx.args[i]->m_var); | |||||
| args[i + 1] = py::cast(ctx.args[i]->m_var); | |||||
| } | } | ||||
| py::object ret = cpp_apply_backward_varnode(py::cast(ctx.op), *args); | |||||
| py::object ret = py::reinterpret_steal<py::object>( | |||||
| PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr)); | |||||
| if (!ret) { | if (!ret) { | ||||
| throw py::value_error("invalid py object call"); | throw py::value_error("invalid py object call"); | ||||
| } | } | ||||
| @@ -36,13 +38,13 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||||
| // assumption: python function always returns PyList | // assumption: python function always returns PyList | ||||
| auto tup = py::reinterpret_borrow<py::list>(ret); | auto tup = py::reinterpret_borrow<py::list>(ret); | ||||
| for (auto i = 0; i < tup.size(); i++) { | for (auto i = 0; i < tup.size(); i++) { | ||||
| auto pitem = tup[i].cast<cg::VarNode *>(); | |||||
| auto pitem = tup[i].cast<cg::VarNode*>(); | |||||
| outputs.emplace_back(std::make_shared<Tensor>(pitem)); | outputs.emplace_back(std::make_shared<Tensor>(pitem)); | ||||
| } | } | ||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| py::object pyf; | |||||
| PyObject* pyf; | |||||
| if (is_compiled) { | if (is_compiled) { | ||||
| // run apply in compiled mode, step 2, 3, etc | // run apply in compiled mode, step 2, 3, etc | ||||
| pyf = cpp_apply_compiled_mode; | pyf = cpp_apply_compiled_mode; | ||||
| @@ -51,11 +53,15 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||||
| pyf = cpp_apply_with_tracing; | pyf = cpp_apply_with_tracing; | ||||
| } | } | ||||
| auto args = py::tuple(ctx.nargs); | |||||
| auto args = py::tuple(ctx.nargs + 1); | |||||
| args[0] = py::cast(ctx.op); | |||||
| for (size_t i = 0; i < ctx.nargs; i++) { | for (size_t i = 0; i < ctx.nargs; i++) { | ||||
| args[i] = TensorWrapper::make(std::move(std::shared_ptr<Tensor>(ctx.args[i]))).release(); | |||||
| args[i + 1] = TensorWrapper::make( | |||||
| std::move(std::shared_ptr<Tensor>(ctx.args[i]))) | |||||
| .release(); | |||||
| } | } | ||||
| auto ret = pyf(py::cast(ctx.op), *args); | |||||
| auto ret = py::reinterpret_steal<py::object>( | |||||
| PyObject_Call(pyf, args.ptr(), nullptr)); | |||||
| // assumption: python function always returns PyList | // assumption: python function always returns PyList | ||||
| auto tup = py::reinterpret_borrow<py::list>(ret); | auto tup = py::reinterpret_borrow<py::list>(ret); | ||||