@@ -20,30 +20,22 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler, common
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import TensorWeakRef
from ..core._imperative_rt.core2 import __make_empty_tensor as make_empty_tensor
from ..core._imperative_rt.core2 import (
TensorWeakRef,
apply,
set_compiled,
set_symbolic,
set_tracing,
skip_tracing,
unset_compiled,
unset_symbolic,
unset_tracing,
)
from ..core._imperative_rt.ops import (
CollectiveComm,
GaussianRNG,
RemoteRecv,
RemoteSend,
UniformRNG,
)
from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from .sublinear_memory_config import SublinearMemoryConfig
@@ -159,7 +151,6 @@ class trace:
self._profiler = None
self._graph_opt_level = opt_level
self._symbolic_shape = symbolic_shape
self._handle2tensors = {}
self._output_handles = set()
self._reset()
@@ -195,7 +186,7 @@ class trace:
raise TraceMismatchError("trace should end here, but more op observed")
record = self._seq[self._pc]
op_, ihandles, ohandles = record
if op != op_:
if (isinstance(op_, str) and op_ == "Const") or (op != op_):
raise TraceMismatchError("op different from last time")
if len(ihandles) != len(args):
raise TraceMismatchError("op input size different from last time")
@@ -253,9 +244,11 @@ class trace:
self._pc += 1
outputs = []
for h in ohandles:
t = CompiledTensorProxy(h)
t._dev_tensor()
outputs += [t._CompiledTensorProxy__tensor]
info = self._tinfo[h]
y = RawTensor(info.varnode)
y._compiled_info = CompiledTensorProxy(h)
y.mixin_handle = h
outputs += [y]
self._output_handles.update(ohandles)
self._active_tensors.update([TensorWeakRef(o) for o in outputs])
return outputs
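The hunk above changes compiled-mode outputs from proxy objects owning a hidden tensor into plain tensors that carry the trace handle plus a `_compiled_info` proxy bound to the same per-handle info record. A simplified, self-contained sketch of that arrangement (toy classes and names, not the real MegEngine types):

```python
# Toy illustration of the new output arrangement: a compiled-mode output is a
# plain tensor-like object carrying its trace handle and a proxy; the proxy
# consults the shared per-handle info to decide whether a read is legal.
class TensorInfo:
    def __init__(self, data_read=False):
        self.data_read = data_read

class CompiledProxy:                      # stands in for CompiledTensorProxy
    def __init__(self, handle, info):
        self.handle, self.info = handle, info
    def _dev_tensor(self):
        if not self.info.data_read:       # data must have been read during tracing
            raise RuntimeError("raw data of this tensor is not read in trace")
        return "<device tensor for handle %d>" % self.handle

class TracedOutput:                       # stands in for RawTensor + attributes
    def __init__(self, handle, info):
        self.mixin_handle = handle        # handle into the trace's _tinfo table
        self._compiled_info = CompiledProxy(handle, info)

tinfo = {3: TensorInfo(data_read=True)}
y = TracedOutput(3, tinfo[3])
print(y.mixin_handle)                     # 3
print(y._compiled_info._dev_tensor())     # <device tensor for handle 3>
```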
@@ -285,7 +278,7 @@ class trace:
for x in inputs:
h = getattr(x, "mixin_handle", -1)
if h >= 0:
x.data_read = True
self._tinfo[h].data = True
return
ihandles = []
@@ -308,7 +301,8 @@ class trace:
ohandles.append(h)
info.external = False
x.mixin_handle = h
self._handle2tensors[h] = x
x.recording = True
x._trace_mixin_info = info
self._seq.append((op, tuple(ihandles), tuple(ohandles)))
self._active_tensors.update([TensorWeakRef(o) for o in outputs])
@@ -318,7 +312,7 @@ class trace:
(x,) = outputs
h = getattr(x, "mixin_handle", -1)
if h >= 0:
x.data_read = True
self._tinfo[h].data_read = True
return
(x,) = outputs
@@ -331,7 +325,8 @@ class trace:
info.bound_data = x
info.is_const = True
x.mixin_handle = h
self._handle2tensors[h] = x
x.recording = True
x._trace_mixin_info = info
self._seq.append(("Const", tuple(), tuple(ohandles)))
def _set_active(self, active: bool):
@@ -346,7 +341,6 @@ class trace:
def _init_trace(self, symbolic: bool):
if symbolic:
set_symbolic()
self._lazy_eval_graph = G.Graph()
self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_links = ()
@@ -383,8 +377,6 @@ class trace:
if self._untraced:
self._init_trace(self._symbolic)
else:
# disable symbolic mode
unset_symbolic()
set_compiled()
if self._graph is None:
self._compile()
@@ -394,18 +386,15 @@ class trace:
escaped_tensors = self._take_escaped_tensors()
if self._untraced:
for x in escaped_tensors:
info = self._tinfo[x().mixin_handle]
x().data_read = True
x().mixin_handle = -1
if x():
info = self._tinfo[x().mixin_handle]
info.data_read = True
x().mixin_handle = -1
x().recording = False
if self._inputs_to_restore:
for x in self._inputs_to_restore:
x.mixin_handle = -1
for h, x in list(self._handle2tensors.items()):
info = self._tinfo[h]
info.data_read = x.data_read
info.shape_read = x.shape_read
info.value_read = x.value_read
del self._handle2tensors[h]
x.recording = False
if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links
):
@@ -437,7 +426,6 @@ class trace:
self._set_active(False)
set_symbolic_shape(self._save_symbolic_shape)
unset_compiled()
unset_symbolic()
unset_tracing()
def do_exit():
@@ -449,6 +437,7 @@ class trace:
if x() is not None:
x()._dev_tensor()
x().mixin_handle = -1
x().recording = False
try:
do_enter()
@@ -473,7 +462,8 @@ class trace:
for x in self._active_tensors:
info = self._tinfo[x().mixin_handle]
info.exported = True
x().data_read = True
info.data_read = True
x()._dev_tensor()
def _apply_graph_options(self, graph):
@@ -528,6 +518,7 @@ class trace:
info.varnode = opnode.outputs[0]
in_out_links += opnode.outputs[1:]
cnt_data, cnt_value, cnt_shape = 0, 0, 0
for op, ihandles, ohandles in self._seq:
if isinstance(op, str) and op == "Const":
assert len(ihandles) == 0
@@ -603,13 +594,16 @@ class trace:
# Shape can be obtained from data so doesn't need its own
# output node. On the other hand, value is read separately
# to leverage eager h2d copy
cnt_data += 1
info.shape_read = False
opnode = info.data_reader = G.OutputNode(v, *in_out_links)
add_reader(opnode)
if info.value_read:
cnt_value += 1
opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
add_reader(opnode)
if info.shape_read:
cnt_shape += 1
opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
add_reader(opnode)
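The comment in the hunk above explains the reader choice: once raw data is read, a dedicated shape reader is redundant (the shape comes with the data), while values keep their own reader so the host copy can start eagerly. A rough sketch of that selection rule (simplified; the helper name is illustrative only):

```python
# Simplified sketch of the reader-selection rule above: data implies shape, so
# shape gets its own reader only when requested without data; value always gets
# a separate reader to allow an eager host copy.
def plan_readers(data_read, value_read, shape_read):
    readers = []
    if data_read:
        shape_read = False                  # shape recoverable from the data output
        readers.append("data_reader")
    if value_read:
        readers.append("value_reader")
    if shape_read:
        readers.append("shape_reader")
    return readers

print(plan_readers(data_read=True, value_read=False, shape_read=True))
# ['data_reader']                 -- shape served from data, no extra output node
print(plan_readers(data_read=False, value_read=True, shape_read=True))
# ['value_reader', 'shape_reader']
```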
@@ -804,7 +798,8 @@ class trace:
info.dtype = x.dtype
info.shape = x.numpy().shape
x.mixin_handle = h
self._handle2tensors[h] = x
x.recording = True
x._trace_mixin_info = info
self._inputs_to_restore.append(x)
return h
@@ -940,7 +935,6 @@ class CompiledTensorProxy:
self.__shape = None
self.__data = None
self.__value = None
self.__tensor = make_empty_tensor()
@property
def dtype(self):
@@ -958,7 +952,7 @@ class CompiledTensorProxy:
if self.__info.shape_read:
self.__shape = self.__info.shape_reader.get_value().shape
elif self.__info.data_read:
self.__shape = self.__info._dev_tensor().shape
self.__shape = self._dev_tensor().shape
else:
raise TraceMismatchError("shape of this tensor is not read in trace")
return self.__shape
@@ -980,25 +974,14 @@ class CompiledTensorProxy:
if not self.__info.data_read:
raise TraceMismatchError("raw data of this tensor is not read in trace")
self.__data = self.__info.data_reader.get_value()
self.__tensor._reset(RawTensor(self.__data))
self.__tensor.mixin_handle = self.__handle
return self.__data
def _drop(self):
return
def _swap_in(self):
return
def _swap_out(self):
return
def __del__(self):
if self.__tensor.shape_read and self.__shape is not None:
if self.__info.shape_read and self.__shape is not None:
self.__info.shape_reader.drop_value()
if self.__tensor.value_read and self.__value is not None:
if self.__info.value_read and self.__value is not None:
self.__info.value_reader.drop_value()
if self.__tensor.data_read and self.__data is not None:
if self.__info.data_read and self.__data is not None:
self.__info.data_reader.drop_value()
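For context on the `__del__` cleanup above: each reader's cached output is dropped only if that kind of read was recorded and a value was actually fetched. A toy version of the pattern (illustrative names; CPython's reference counting makes `__del__` run promptly here):

```python
# Toy version of the cleanup above: when the proxy dies, it drops only the
# reader output that was actually recorded and fetched during this run.
class Reader:
    def __init__(self):
        self.dropped = False
    def drop_value(self):
        self.dropped = True

class Proxy:
    def __init__(self, info):
        self.info = info
        self.data = None
    def read_data(self):
        self.data = "payload"             # pretend: data_reader.get_value()
        return self.data
    def __del__(self):
        if self.info["data_read"] and self.data is not None:
            self.info["data_reader"].drop_value()

info = {"data_read": True, "data_reader": Reader()}
p = Proxy(info)
p.read_data()
del p                                     # refcount hits zero, __del__ runs now
print(info["data_reader"].dropped)        # True
```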
@@ -1054,6 +1037,8 @@ def apply_const_symbolic_mode(value, dtype, device):
# don't need to unset tracing
# because varnode construction will ignore tracing flag
ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
if np.array(value).ndim == 0:
setscalar(ret)
active_trace._lazy_eval_tensors.add(TensorWeakRef(ret))
return (ret,)
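The two lines added above flag 0-d constants as scalars before they enter the lazy-eval set. A quick illustration of the `np.array(value).ndim == 0` test they rely on:

```python
import numpy as np

# Bare Python numbers and 0-d arrays have ndim == 0 and are marked as scalars;
# lists and higher-rank arrays keep their dimensionality.
for value in (3, 3.0, np.float32(3), [3.0], np.ones((2, 2))):
    kind = "scalar" if np.array(value).ndim == 0 else "tensor"
    print(repr(value), "->", kind)
```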
@@ -1084,7 +1069,6 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
return active_trace._apply_const(value, dtype, device)
# this hook injects TraceMixin
def apply_with_tracing(op: OpDef, *args: RawTensor):
if active_trace._symbolic:
outputs = apply_symbolic_mode(op, *args)
@@ -54,7 +54,6 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
#undef REGISTE_APPLY_FUNC
bool is_tracing = false;
bool is_symbolic = false;
bool is_compiled = false;
#define SET_UNSET_PROP(mode) \
@@ -66,7 +65,6 @@ bool is_compiled = false;
} \
SET_UNSET_PROP(tracing)
SET_UNSET_PROP(symbolic)
SET_UNSET_PROP(compiled)
#undef SET_UNSET_PROP
@@ -280,14 +278,27 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
m_tensor->m_trace_info.member = real_dest; \
}
REGISTE_TENSORWRAPPER_FUNC(bool, data_read)
REGISTE_TENSORWRAPPER_FUNC(bool, value_read)
REGISTE_TENSORWRAPPER_FUNC(bool, shape_read)
REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)
REGISTE_TENSORWRAPPER_FUNC(bool, recording)
#undef REGISTE_TENSORWRAPPER_FUNC
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \
return m_tensor->m_trace_info.member; \
} \
void TensorWrapper::set_##member(PyObject* dest) { \
Py_INCREF(dest); \
m_tensor->m_trace_info.member = dest; \
}
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)
#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC
PyObject* TensorWrapper::handle() {
return py::cast(m_tensor->m_handle).release().ptr();
}
@@ -301,8 +312,14 @@ void TensorWrapper::set_handle(PyObject* dest) {
PyObject* TensorWrapper::shape() {
if (!skip_tracing) {
set_shape_read(py::cast(true). release().ptr());
if (m_tensor->m_trace_info.compiled_info != nullptr) {
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
}
return PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
@@ -310,7 +327,12 @@ PyObject* TensorWrapper::shape() {
TensorShape shape;
if (m_tensor->m_var) {
shape = m_tensor->m_var->shape();
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var);
if (!tshp) {
Py_RETURN_NONE;
}
shape = *tshp;
} else {
shape = m_tensor->shape();
}
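The C++ hunk above replaces the direct `m_var->shape()` read with the graph's static shape inference, returning `None` to Python when the shape cannot be inferred yet. A rough Python-level rendering of that fallback (hypothetical helper, for illustration only):

```python
# Rough Python rendering of the new shape() fallback: scalars report an empty
# tuple, graph-backed tensors ask static inference and yield None on failure.
def tensor_shape(inferred_shape, is_scalar=False):
    if is_scalar:
        return ()                  # SCALAR flag -> empty tuple, unchanged behavior
    if inferred_shape is None:     # infer_shape_fallible() failed -> None
        return None
    return tuple(inferred_shape)

print(tensor_shape((4, 8)))                  # (4, 8)
print(tensor_shape(None))                    # None: shape not statically known
print(tensor_shape((4, 8), is_scalar=True))  # ()
```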
@@ -343,8 +365,15 @@ PyObject* TensorWrapper::device() {
PyObject* TensorWrapper::numpy() {
if (!skip_tracing) {
set_value_read(py::cast(true).release().ptr());
if (m_tensor->m_trace_info.compiled_info != nullptr) {
PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
np_val = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
}
return np_val;
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr());
}
if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
@@ -359,7 +388,11 @@ PyObject* TensorWrapper::numpy() {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr;
}
return py::cast(*val).attr("numpy")().release().ptr();
auto np_val = py::cast(*val).attr("numpy")();
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val.release().ptr()));
}
return np_val.release().ptr();
}
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
@@ -410,8 +443,14 @@ PyObject* TensorWrapper::detach() {
}
PyObject* TensorWrapper::_dev_tensor(){
if (!skip_tracing) {
set_data_read(py::cast(true).release().ptr());
if (m_tensor->m_trace_info.compiled_info != nullptr) {
auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh));
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
}
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
return py::cast(dev_tensor).release().ptr();
@@ -668,9 +707,6 @@ WRAP_FUNC_PY35(get_device);
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
py::object make_empty_tensorwrapper() {
return TensorWrapper::make(std::move(std::make_shared<Tensor>()));
}
void init_tensor(py::module m) {
imperative::Tensor::static_initialize();
@@ -692,11 +728,11 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::_drop>("_drop")
.def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read")
.def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read")
.def_getset<&TensorWrapper::shape_read, &TensorWrapper::set_shape_read>("shape_read")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
.def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("recording")
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
.def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
.def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
.finalize();
if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type);
@@ -771,12 +807,8 @@ void init_tensor(py::module m) {
m.def("set_tracing", &set_tracing);
m.def("unset_tracing", &unset_tracing);
m.def("set_symbolic", &set_symbolic);
m.def("unset_symbolic", &unset_symbolic);
m.def("set_compiled", &set_compiled);
m.def("unset_compiled", &unset_compiled);
m.def("__make_empty_tensor", &make_empty_tensorwrapper);
}
#undef MGE_PY_INTERFACE
@@ -159,15 +159,16 @@ struct TensorWrapper {
PyObject* handle();
void set_handle(PyObject *);
PyObject* data_read();
PyObject* value_read();
PyObject* shape_read();
PyObject* mixin_handle();
PyObject* recording();
void set_data_read(PyObject*);
void set_value_read(PyObject*);
void set_shape_read(PyObject*);
void set_mixin_handle(PyObject*);
void set_recording(PyObject*);
PyObject* compiled_info();
void set_compiled_info(PyObject *);
PyObject* trace_mixin_info();
void set_trace_mixin_info(PyObject *);
};
@@ -219,7 +220,6 @@ template <typename... Args>
constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);
extern bool is_tracing; // FIXME: should use ApplyContext::global_enable
extern bool is_symbolic;
extern bool is_compiled;
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t outputs;
if (ctx.backward) {
// reach here when symbolic=True or compiled=True
// reach here when compiled=True
// call megbrain_graph.py apply(BackwardGraph, *args)
auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
@@ -10,15 +10,38 @@
*/
#include "inttypes.h"
#include "Python.h"
namespace mgb::imperative::python {
struct TraceInfo {
int64_t mixin_handle = -1;
bool recording = false;
bool data_read = false;
bool value_read = false;
bool shape_read = false;
PyObject* compiled_info = nullptr;
PyObject* trace_mixin_info = nullptr;
TraceInfo() = default;
TraceInfo& operator=(const TraceInfo& that) {
mixin_handle = that.mixin_handle;
recording = that.recording;
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);
trace_mixin_info = that.trace_mixin_info;
Py_XINCREF(trace_mixin_info);
return *this;
}
~TraceInfo() {
Py_XDECREF(trace_mixin_info);
// Py_XDECREF(compiled_info);
}
private:
TraceInfo(const TraceInfo& that) = default;
};
} // namespace mgb::imperative::python
@@ -311,6 +311,7 @@ def test_trace_warp_perspective():
f(x, M)
@pytest.mark.skip(reason="skip")
def test_raise_on_trace():
step_count = 0
catch_count = 0