| @@ -350,12 +350,16 @@ class trace: | |||||
| self._lazy_eval_links = () | self._lazy_eval_links = () | ||||
| def _take_escaped_tensors(self): | def _take_escaped_tensors(self): | ||||
| escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values())) | |||||
| escaped_tensors = tuple( | |||||
| filter(lambda x: x() is not None, self._active_tensors.values()) | |||||
| ) | |||||
| self._active_tensors.clear() | self._active_tensors.clear() | ||||
| return escaped_tensors | return escaped_tensors | ||||
| def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | ||||
| lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values())) | |||||
| lazy_eval_tensors = list( | |||||
| filter(lambda x: x() is not None, lazy_eval_tensors.values()) | |||||
| ) | |||||
| readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] | readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] | ||||
| self._apply_graph_options(lazy_eval_graph) | self._apply_graph_options(lazy_eval_graph) | ||||
| # FIXME | # FIXME | ||||
| @@ -443,6 +447,7 @@ class trace: | |||||
| x()._reset_varnode() | x()._reset_varnode() | ||||
| x().mixin_handle = -1 | x().mixin_handle = -1 | ||||
| x().recording = False | x().recording = False | ||||
| x()._trace_mixin_info = None | |||||
| try: | try: | ||||
| do_enter() | do_enter() | ||||
| @@ -294,8 +294,13 @@ PyObject* TensorWrapper::copied() { | |||||
| return m_tensor->m_trace_info.member; \ | return m_tensor->m_trace_info.member; \ | ||||
| } \ | } \ | ||||
| void TensorWrapper::set_##member(PyObject* dest) { \ | void TensorWrapper::set_##member(PyObject* dest) { \ | ||||
| Py_INCREF(dest); \ | |||||
| m_tensor->m_trace_info.member = dest; \ | |||||
| if (dest == Py_None) { \ | |||||
| Py_XDECREF(m_tensor->m_trace_info.member); \ | |||||
| m_tensor->m_trace_info.member = nullptr; \ | |||||
| } else { \ | |||||
| Py_INCREF(dest); \ | |||||
| m_tensor->m_trace_info.member = dest; \ | |||||
| } \ | |||||
| } | } | ||||
| REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info) | REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info) | ||||
| @@ -463,6 +468,8 @@ PyObject* TensorWrapper::_dev_tensor(){ | |||||
| auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); | auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); | ||||
| auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); | auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); | ||||
| m_tensor->m_handle = std::move(SharedHandle(sh)); | m_tensor->m_handle = std::move(SharedHandle(sh)); | ||||
| Py_DECREF(m_tensor->m_trace_info.compiled_info); | |||||
| m_tensor->m_trace_info.compiled_info = nullptr; | |||||
| return dev_tensor; | return dev_tensor; | ||||
| } | } | ||||
| @@ -55,10 +55,9 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||||
| auto args = py::tuple(ctx.nargs + 1); | auto args = py::tuple(ctx.nargs + 1); | ||||
| args[0] = py::cast(ctx.op); | args[0] = py::cast(ctx.op); | ||||
| py::tuple args(ctx.nargs); | |||||
| for (size_t i = 0; i < ctx.nargs; i++) { | for (size_t i = 0; i < ctx.nargs; i++) { | ||||
| args[i + 1] = TensorWrapper::make( | |||||
| std::move(std::shared_ptr<Tensor>(ctx.args[i]))) | |||||
| .release(); | |||||
| args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); | |||||
| } | } | ||||
| auto ret = py::reinterpret_steal<py::object>( | auto ret = py::reinterpret_steal<py::object>( | ||||
| PyObject_Call(pyf, args.ptr(), nullptr)); | PyObject_Call(pyf, args.ptr(), nullptr)); | ||||
| @@ -28,10 +28,10 @@ struct TraceInfo { | |||||
| mixin_handle = that.mixin_handle; | mixin_handle = that.mixin_handle; | ||||
| recording = that.recording; | recording = that.recording; | ||||
| compiled_info = that.compiled_info; | |||||
| Py_XINCREF(compiled_info); | |||||
| trace_mixin_info = that.trace_mixin_info; | trace_mixin_info = that.trace_mixin_info; | ||||
| Py_XINCREF(trace_mixin_info); | Py_XINCREF(trace_mixin_info); | ||||
| compiled_info = that.compiled_info; | |||||
| Py_XINCREF(compiled_info); | |||||
| copied = true; | copied = true; | ||||
| return *this; | return *this; | ||||
| @@ -39,7 +39,7 @@ struct TraceInfo { | |||||
| ~TraceInfo() { | ~TraceInfo() { | ||||
| Py_XDECREF(trace_mixin_info); | Py_XDECREF(trace_mixin_info); | ||||
| // Py_XDECREF(compiled_info); | |||||
| Py_XDECREF(compiled_info); | |||||
| } | } | ||||
| private: | private: | ||||