GitOrigin-RevId: fd1265c661
tags/v1.5.0
| @@ -170,9 +170,9 @@ class trace: | |||
| self._graph = None | |||
| self._need_reset_nodes = None | |||
| self._lazy_eval_graph = None | |||
| self._lazy_eval_tensors = {} | |||
| self._lazy_eval_tensors = set() | |||
| self._lazy_eval_links = None | |||
| self._active_tensors = {} | |||
| self._active_tensors = set() | |||
| self._tensor_remaps = None | |||
| self._inputs_to_restore = None | |||
| self._arg_bindings = None | |||
| @@ -258,7 +258,7 @@ class trace: | |||
| y._compiled_info = CompiledTensorProxy(h) | |||
| y._mixin_handle = h | |||
| outputs += [y] | |||
| self._active_tensors[h] = TensorWeakRef(y) | |||
| self._active_tensors.add(TensorWeakRef(y)) | |||
| self._output_handles.update(ohandles) | |||
| return outputs | |||
| @@ -318,9 +318,9 @@ class trace: | |||
| x._mixin_handle = h | |||
| x._recording = True | |||
| x._trace_mixin_info = info | |||
| self._active_tensors[h] = TensorWeakRef(x) | |||
| self._active_tensors.add(TensorWeakRef(x)) | |||
| if self._symbolic: | |||
| self._lazy_eval_tensors[h] = TensorWeakRef(x) | |||
| self._lazy_eval_tensors.add(TensorWeakRef(x)) | |||
| self._seq.append((op, tuple(ihandles), tuple(ohandles))) | |||
| @@ -345,7 +345,7 @@ class trace: | |||
| x._recording = True | |||
| x._trace_mixin_info = info | |||
| if self._symbolic: | |||
| self._lazy_eval_tensors[h] = TensorWeakRef(x) | |||
| self._lazy_eval_tensors.add(TensorWeakRef(x)) | |||
| self._seq.append(("Const", tuple(), tuple(ohandles))) | |||
| def _set_active(self, active: bool): | |||
| @@ -365,17 +365,14 @@ class trace: | |||
| self._lazy_eval_links = () | |||
| def _take_escaped_tensors(self): | |||
| escaped_tensors = tuple( | |||
| filter(lambda x: x() is not None, self._active_tensors.values()) | |||
| ) | |||
| escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors)) | |||
| self._active_tensors.clear() | |||
| return escaped_tensors | |||
| def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | |||
| lazy_eval_tensors = list( | |||
| filter(lambda x: x() is not None, lazy_eval_tensors.values()) | |||
| ) | |||
| readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] | |||
| lazy_eval_tensors = [x() for x in lazy_eval_tensors] | |||
| lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None] | |||
| readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors] | |||
| self._apply_graph_options(lazy_eval_graph) | |||
| lazy_eval_graph.options.graph_opt_level = self._graph_opt_level | |||
| lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers]) | |||
| @@ -383,8 +380,8 @@ class trace: | |||
| lazy_eval_graph() | |||
| for r, x in zip(readers, lazy_eval_tensors): | |||
| # get values from lazy_eval_graph and assign to lazy_eval tensor | |||
| x()._handle = RawTensor(r.op.get_value())._handle | |||
| x()._reset_varnode() | |||
| x._handle = RawTensor(r.op.get_value())._handle | |||
| x._reset_varnode() | |||
| @contextlib.contextmanager | |||
| def _setup(self): | |||
| @@ -454,13 +451,14 @@ class trace: | |||
| raise TraceMismatchError("premature end") | |||
| if not self._symbolic or not self._untraced: | |||
| # reset output tensors | |||
| for x in self._active_tensors.values(): | |||
| if x() is not None: | |||
| x()._dev_tensor() | |||
| x()._reset_varnode() | |||
| x()._mixin_handle = -1 | |||
| x()._recording = False | |||
| x()._trace_mixin_info = None | |||
| for x in self._active_tensors.copy(): | |||
| strong_x = x() | |||
| if strong_x is not None: | |||
| strong_x._dev_tensor() | |||
| strong_x._reset_varnode() | |||
| strong_x._mixin_handle = -1 | |||
| strong_x._recording = False | |||
| strong_x._trace_mixin_info = None | |||
| try: | |||
| do_enter() | |||
| @@ -482,15 +480,17 @@ class trace: | |||
| if self._untraced: | |||
| # conditionally reading a compiled tensor in excluded region | |||
| # is permitted, so we have to assume every tensor might be read | |||
| for x in self._active_tensors.values(): | |||
| if x(): | |||
| info = self._tinfo[x()._mixin_handle] | |||
| for x in self._active_tensors: | |||
| strong_x = x() | |||
| if strong_x: | |||
| info = self._tinfo[strong_x._mixin_handle] | |||
| info.exported = True | |||
| info.data_read = True | |||
| else: | |||
| for x in self._active_tensors.values(): | |||
| if x(): | |||
| x()._dev_tensor() | |||
| for x in self._active_tensors: | |||
| strong_x = x() | |||
| if strong_x: | |||
| strong_x._dev_tensor() | |||
| def _apply_graph_options(self, graph): | |||
| @@ -520,7 +520,6 @@ class trace: | |||
| graph = self._graph = G.Graph() | |||
| graph.options.async_exec_level = 0b100 | |||
| self._apply_graph_options(graph) | |||
| # graph.options.graph_opt_level = 0 | |||
| need_reset_nodes = self._need_reset_nodes = [] | |||
| # links enforce ordering of I/O nodes | |||
| in_out_links = () | |||
| @@ -563,7 +562,7 @@ class trace: | |||
| if not hasattr(info, "varnode"): | |||
| assert info.external | |||
| if info.bound_data: | |||
| if hasattr(info, "is_const") and info.is_const: | |||
| if getattr(info, "is_const", False): | |||
| info.varnode = graph.make_const( | |||
| info.bound_data.numpy(), | |||
| info.bound_data.dtype, | |||
| @@ -635,30 +634,12 @@ class trace: | |||
| opnode.reset() | |||
| def __call__(self, *args, **kwargs): | |||
| if is_tracing(): | |||
| return self.__wrapped__(*args, **kwargs) | |||
| with self._setup(): | |||
| if self._capture_as_const: | |||
| self._process_inputs(*args, **kwargs) | |||
| outputs = self.__wrapped__(*args, **kwargs) | |||
| if self._capture_as_const: | |||
| self._process_outputs(outputs) | |||
| # outputs could be None | |||
| if outputs is not None: | |||
| list_outputs = outputs | |||
| if isinstance(outputs, collections.abc.Mapping): | |||
| _, list_outputs = zip(*sorted(outputs.items())) | |||
| elif not isinstance(outputs, collections.abc.Sequence): | |||
| list_outputs = (outputs,) | |||
| for o in list_outputs: | |||
| # if outputs are copied, then use the newest info in trace data structure | |||
| if o._copied: | |||
| self._active_tensors[o._mixin_handle] = TensorWeakRef(o) | |||
| if self._untraced and self._symbolic: | |||
| self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o) | |||
| return outputs | |||
| def dump( | |||
| @@ -9,11 +9,12 @@ | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma GCC diagnostic ignored "-Wmissing-field-initializers" | |||
| #include "./grad.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| #include "megbrain/imperative/backward_graph_opt.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "megbrain/utils/mempool.h" | |||
| #include "range/v3/all.hpp" | |||
| @@ -434,7 +435,8 @@ apply_result_t apply_grad(ApplyContext& ctx) { | |||
| if (backward.output_requires_grad(i)) { | |||
| if (backward.output_captured(i)) { | |||
| // avoid reference cycle [Tensor <-> GradFn] | |||
| outputs[i] = outputs[i]->copy(); | |||
| static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy()); | |||
| outputs[i] = python::apply(op, outputs[i])[0]; | |||
| } | |||
| // populate grad info of output tensor | |||
| auto& grad_info = outputs[i]->m_grad_info; | |||
| @@ -12,6 +12,7 @@ | |||
| #pragma once | |||
| #include "./tensor.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include <megbrain/utils/small_vector.h> | |||
| #include <memory> | |||
| @@ -221,6 +221,21 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& ma | |||
| return apply(ctx); | |||
| } | |||
| apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
| mgb_assert(ctx.nargs == 1); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) { | |||
| mgb_assert(ngrads == 1); | |||
| Tensor* grad = grads[0]; | |||
| apply_result_t ret(1); | |||
| if (grad) { | |||
| ret[0] = grad->shared_from_this(); | |||
| } | |||
| return ret; | |||
| }); | |||
| return apply(ctx); | |||
| } | |||
| struct Init { | |||
| Init() { | |||
| auto& reg = grad_rule_registry(); | |||
| @@ -231,6 +246,7 @@ struct Init { | |||
| reg.emplace(Reduce::typeinfo(), reduce_grad_rule); | |||
| reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); | |||
| reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); | |||
| reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||
| } | |||
| } _; | |||
| @@ -23,6 +23,7 @@ | |||
| #include "./common.h" | |||
| #include "./ops.h" | |||
| #include "megbrain/gopt/inference.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| namespace py = pybind11; | |||
| @@ -118,9 +118,18 @@ apply_result_t apply(ApplyContext& ctx) { | |||
| handles[i] = ctx.args[i]->m_handle.get(); | |||
| } | |||
| apply_result_t outputs; | |||
| // fast copy without really applying | |||
| if (ctx.op->same_type<FastpathCopy>()) { | |||
| mgb_assert(ctx.nargs == 1); | |||
| outputs.reserve(ctx.nargs); | |||
| outputs.emplace_back(std::make_shared<Tensor>(ctx.args[0]->m_handle)); | |||
| return outputs; | |||
| } | |||
| auto output_handles = interpreter_for_py->apply_op(ctx.op, handles); | |||
| apply_result_t outputs; | |||
| outputs.reserve(output_handles.size()); | |||
| for (auto h : output_handles) { | |||
| outputs.emplace_back(std::make_shared<Tensor>(h)); | |||
| @@ -303,11 +312,6 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording) | |||
| #undef REGISTE_TENSORWRAPPER_FUNC | |||
| PyObject* TensorWrapper::copied() { | |||
| return py::cast(m_tensor->m_trace_info.copied).release().ptr(); | |||
| } | |||
| #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ | |||
| PyObject* TensorWrapper::member() { \ | |||
| if (m_tensor->m_trace_info.member) { \ | |||
| @@ -841,7 +845,6 @@ void init_tensor(py::module m) { | |||
| .def<&TensorWrapper::reset_varnode>("_reset_varnode") | |||
| .def<&TensorWrapper::_use_cnt>("_use_cnt") | |||
| .def_getset<&TensorWrapper::varnode>("_varnode") | |||
| .def_getset<&TensorWrapper::copied>("_copied") | |||
| .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle") | |||
| .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording") | |||
| .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") | |||
| @@ -10,6 +10,7 @@ | |||
| */ | |||
| #pragma once | |||
| #pragma GCC diagnostic ignored "-Wmissing-field-initializers" | |||
| #include <variant> | |||
| @@ -35,7 +35,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||
| // assumption: python function always returns PyList | |||
| auto tup = py::reinterpret_borrow<py::list>(ret); | |||
| for (auto i = 0; i < tup.size(); i++) { | |||
| for (size_t i = 0; i < tup.size(); i++) { | |||
| auto pitem = tup[i].cast<cg::VarNode*>(); | |||
| outputs.emplace_back(std::make_shared<Tensor>(pitem)); | |||
| } | |||
| @@ -17,7 +17,6 @@ namespace mgb::imperative::python { | |||
| struct TraceInfo { | |||
| int64_t mixin_handle = -1; | |||
| bool recording = false; | |||
| bool copied = false; | |||
| // refer to CompiledTensorProxy in tracing.py, works from second trace step | |||
| PyObject* compiled_info = nullptr; | |||
| @@ -35,7 +34,6 @@ struct TraceInfo { | |||
| compiled_info = that.compiled_info; | |||
| Py_XINCREF(compiled_info); | |||
| copied = true; | |||
| return *this; | |||
| } | |||
| @@ -18,4 +18,18 @@ namespace mgb::imperative { | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); | |||
| namespace { namespace fastpathcopy { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| return inputs; | |||
| } | |||
| OP_TRAIT_REG(FastpathCopy,FastpathCopy) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // fastpathcopy | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy); | |||
| } // namespace mgb::imperative | |||
| @@ -35,4 +35,18 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> { | |||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
| }; | |||
| struct FastpathCopy final : OpDefImplBase<FastpathCopy> { | |||
| FastpathCopy() = default; | |||
| size_t hash() const override { | |||
| return mgb::hash(this->dyn_typeinfo()); | |||
| } | |||
| bool is_same_st(const Hashable& rhs) const override { | |||
| return this->dyn_typeinfo() == rhs.dyn_typeinfo(); | |||
| } | |||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
| }; | |||
| } // namespace mgb::imperative | |||