
fix(mge): correct trace outputs when grad does copy

GitOrigin-RevId: 65c8956a7d
tags/v1.2.0
Megvii Engine Team, 5 years ago
parent commit c70a49ed2c
4 changed files with 35 additions and 11 deletions:

  1. imperative/python/megengine/jit/tracing.py  (+26 -11)
  2. imperative/python/src/tensor.cpp  (+6 -0)
  3. imperative/python/src/tensor.h  (+1 -0)
  4. imperative/python/src/trace_info.h  (+2 -0)
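
What changed, in brief: under `trace`, autodiff can hand back a copy of an output tensor (its `TraceInfo` is copy-assigned), and the copy was never registered with the active trace, so the trace's view of the outputs went stale. The fix marks such copies in C++ (`TraceInfo::copied`, exposed to Python as `_copied`) and re-registers them by `mixin_handle` when the traced function returns, which is also why the weakref containers become dicts keyed by handle. A minimal sketch of the re-registration idea, with a hypothetical `FakeTensor` standing in for real tensors and `weakref.ref` for `TensorWeakRef`:

    import weakref

    class FakeTensor:
        """Hypothetical stand-in for a traced tensor."""
        def __init__(self, mixin_handle, copied=False):
            self.mixin_handle = mixin_handle
            self._copied = copied

    active_tensors = {}  # mixin_handle -> weakref, as in the new bookkeeping

    def reregister(outputs):
        # a copied output replaces its stale predecessor under the same handle
        for o in outputs:
            if o._copied:
                active_tensors[o.mixin_handle] = weakref.ref(o)

    orig = FakeTensor(7)
    active_tensors[7] = weakref.ref(orig)
    copy = FakeTensor(7, copied=True)  # grad produced a copy sharing handle 7
    reregister([copy])
    assert active_tensors[7]() is copy  # the trace now sees the live copy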

imperative/python/megengine/jit/tracing.py  (+26 -11)

@@ -163,9 +163,9 @@ class trace:
         self._graph = None
         self._need_reset_nodes = None
         self._lazy_eval_graph = None
-        self._lazy_eval_tensors = set()
+        self._lazy_eval_tensors = {}
         self._lazy_eval_links = None
-        self._active_tensors = set()
+        self._active_tensors = {}
         self._tensor_remaps = None
         self._inputs_to_restore = None
         self._arg_bindings = None
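
Note on the container change: with a set of weakrefs, an original and its copy both linger; with a dict keyed by handle, a later registration replaces the earlier one. A toy comparison (strings standing in for weakrefs):

    tensors_set = set()
    tensors_set.add(("handle7", "original"))
    tensors_set.add(("handle7", "copy"))
    assert len(tensors_set) == 2        # two entries for one logical tensor

    tensors_dict = {}
    tensors_dict["handle7"] = "original"
    tensors_dict["handle7"] = "copy"    # same key: the copy wins
    assert tensors_dict == {"handle7": "copy"}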
@@ -249,8 +249,8 @@ class trace:
             y._compiled_info = CompiledTensorProxy(h)
             y.mixin_handle = h
             outputs += [y]
+            self._active_tensors[h] = TensorWeakRef(y)
         self._output_handles.update(ohandles)
-        self._active_tensors.update([TensorWeakRef(o) for o in outputs])
         return outputs

     def _apply_const(self, value, dtype, device):
@@ -303,9 +303,11 @@ class trace:
             x.mixin_handle = h
             x.recording = True
             x._trace_mixin_info = info
+            self._active_tensors[h] = TensorWeakRef(x)
+            if self._symbolic:
+                self._lazy_eval_tensors[h] = TensorWeakRef(x)

         self._seq.append((op, tuple(ihandles), tuple(ohandles)))
-        self._active_tensors.update([TensorWeakRef(o) for o in outputs])

     def _record_const(self, outputs):
         if skip_tracing:
@@ -327,6 +329,8 @@ class trace:
             x.mixin_handle = h
             x.recording = True
             x._trace_mixin_info = info
+            if self._symbolic:
+                self._lazy_eval_tensors[h] = TensorWeakRef(x)
         self._seq.append(("Const", tuple(), tuple(ohandles)))

     def _set_active(self, active: bool):
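
With these two hunks, outputs are registered as they are recorded, one handle at a time, instead of in a bulk `update(...)` afterwards. A condensed, hypothetical analog of the record-time pattern:

    import weakref

    class Recorder:
        """Hypothetical, stripped-down analog of the trace bookkeeping."""
        def __init__(self, symbolic):
            self._symbolic = symbolic
            self._active_tensors = {}
            self._lazy_eval_tensors = {}

        def record_output(self, h, x):
            # register each output under its handle at record time
            self._active_tensors[h] = weakref.ref(x)
            if self._symbolic:
                self._lazy_eval_tensors[h] = weakref.ref(x)

    class T: pass

    r = Recorder(symbolic=True)
    t = T()
    r.record_output(3, t)
    assert r._active_tensors[3]() is t and r._lazy_eval_tensors[3]() is t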
@@ -346,12 +350,12 @@ class trace:
         self._lazy_eval_links = ()

     def _take_escaped_tensors(self):
-        escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
+        escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values()))
         self._active_tensors.clear()
         return escaped_tensors

     def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
-        lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors))
+        lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values()))
         readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
         self._apply_graph_options(lazy_eval_graph)
         # FIXME
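
Consumers now iterate `.values()` and drop dead weakrefs exactly as before; only the container shape changed. The filtering behaves like this (CPython's refcounting collects `dead` immediately):

    import weakref

    class T: pass

    live, dead = T(), T()
    active_tensors = {1: weakref.ref(live), 2: weakref.ref(dead)}
    del dead  # its weakref now returns None

    escaped = tuple(r for r in active_tensors.values() if r() is not None)
    assert len(escaped) == 1 and escaped[0]() is live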
@@ -401,7 +405,7 @@ class trace:
             # eval lazy eval tensors
             self._lazy_eval(
                 self._lazy_eval_graph,
-                tuple(self._lazy_eval_tensors),
+                self._lazy_eval_tensors,
                 self._lazy_eval_links,
             )
             self._lazy_eval_graph = None
@@ -433,9 +437,10 @@ class trace:
         if not self._untraced and self._pc != len(self._seq):
             raise TraceMismatchError("premature end")
         if not self._symbolic or not self._untraced:
-            for x in self._active_tensors:
+            for x in self._active_tensors.values():
                 if x() is not None:
                     x()._dev_tensor()
+                    x()._reset_varnode()
                     x().mixin_handle = -1
                     x().recording = False

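Besides iterating `.values()`, the exit path gains `x()._reset_varnode()`, so a tensor that outlives the trace does not keep a stale graph varnode alive. In spirit (hypothetical `T` mimicking only the fields touched here):

    import weakref

    class T:
        def __init__(self):
            self.mixin_handle, self.recording, self.varnode = 7, True, object()
        def _dev_tensor(self):      # stand-in: materialize the device value
            pass
        def _reset_varnode(self):   # stand-in: drop the graph varnode
            self.varnode = None

    t = T()
    for ref in {7: weakref.ref(t)}.values():
        x = ref()
        if x is not None:           # the weakref may already be dead
            x._dev_tensor()
            x._reset_varnode()      # new in this change
            x.mixin_handle = -1
            x.recording = False
    assert t.varnode is None and t.mixin_handle == -1
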
@@ -459,7 +464,7 @@ class trace:
         if self._untraced:
             # conditionally reading a compiled tensor in excluded region
             # is permitted, so we have to assume every tensor might be read
-            for x in self._active_tensors:
+            for x in self._active_tensors.values():
                 info = self._tinfo[x().mixin_handle]
                 info.exported = True
                 info.data_read = True
@@ -626,8 +631,20 @@ class trace:
         if self._capture_as_const:
             self._process_inputs(*args, **kwargs)
         outputs = self.__wrapped__(*args, **kwargs)
+        transform = False
+        if outputs is not None:
+            if not isinstance(outputs, collections.abc.Sequence):
+                transform = True
+                outputs = (outputs,)
+            for o in outputs:
+                if o._copied:
+                    self._active_tensors[o.mixin_handle] = TensorWeakRef(o)
+                    if self._untraced and self._symbolic:
+                        self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o)
         if self._capture_as_const:
             self._process_outputs(outputs)
+        if transform:
+            outputs = outputs[0]
         return outputs

     def dump(
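
The `__call__` hunk normalizes a lone output to a one-element tuple so the per-output bookkeeping can run uniformly, then unwraps before returning. The wrap/unwrap pattern in isolation (bookkeeping elided):

    import collections.abc

    def call(fn, *args):
        outputs = fn(*args)
        transform = False
        if outputs is not None and not isinstance(outputs, collections.abc.Sequence):
            transform = True
            outputs = (outputs,)   # normalize to a sequence
        # ... per-output re-registration would run here, as in the diff ...
        if transform:
            outputs = outputs[0]   # restore the caller-visible shape
        return outputs

    assert call(lambda: 42) == 42          # scalar in, scalar out
    assert call(lambda: (1, 2)) == (1, 2)  # sequences pass through untouched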
@@ -1031,7 +1048,6 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
     if require_links:
         active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)

-    active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs])
     return outputs

@@ -1042,7 +1058,6 @@ def apply_const_symbolic_mode(value, dtype, device):
     ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
     if np.array(value).ndim == 0:
         setscalar(ret)
-    active_trace._lazy_eval_tensors.add(TensorWeakRef(ret))
     return (ret,)




imperative/python/src/tensor.cpp  (+6 -0)

@@ -284,6 +284,11 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording)
 #undef REGISTE_TENSORWRAPPER_FUNC


+PyObject* TensorWrapper::copied() {
+    return py::cast(m_tensor->m_trace_info.copied).release().ptr();
+}
+
+
 #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
     PyObject* TensorWrapper::member() { \
         return m_tensor->m_trace_info.member; \
@@ -740,6 +745,7 @@ void init_tensor(py::module m) {
         .def<&TensorWrapper::_drop>("_drop")
         .def<&TensorWrapper::reset_varnode>("_reset_varnode")
         .def_getset<&TensorWrapper::varnode>("_varnode")
+        .def_getset<&TensorWrapper::copied>("_copied")
         .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
         .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("recording")
         .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
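
`_copied` is registered with a getter only (no setter alongside it), so from Python it is effectively read-only; the flag is flipped exclusively by the C++ copy-assignment below. A read-only property is the closest Python analog (hypothetical class, for illustration):

    class TensorWrapperAnalog:
        """Hypothetical analog: `_copied` is readable but not writable."""
        def __init__(self):
            self.__copied = False
        def _mark_copied(self):     # done by TraceInfo's operator= in C++
            self.__copied = True
        @property
        def _copied(self):
            return self.__copied

    t = TensorWrapperAnalog()
    assert t._copied is False
    t._mark_copied()
    assert t._copied is True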


imperative/python/src/tensor.h  (+1 -0)

@@ -161,6 +161,7 @@ struct TensorWrapper {

     PyObject* mixin_handle();
     PyObject* recording();
+    PyObject* copied();

     void set_mixin_handle(PyObject*);
     void set_recording(PyObject*);


imperative/python/src/trace_info.h  (+2 -0)

@@ -17,6 +17,7 @@ namespace mgb::imperative::python {
 struct TraceInfo {
     int64_t mixin_handle = -1;
     bool recording = false;
+    bool copied = false;

     PyObject* compiled_info = nullptr;
     PyObject* trace_mixin_info = nullptr;
@@ -32,6 +33,7 @@
         trace_mixin_info = that.trace_mixin_info;
         Py_XINCREF(trace_mixin_info);

+        copied = true;
         return *this;
     }
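
The assignment operator now brands the destination: `copied` starts out false and becomes true only when one tensor's trace info is copy-assigned from another's, which is exactly what happens when grad copies a tensor. A Python analog of that semantics (hypothetical class, fields simplified):

    class TraceInfoAnalog:
        """Hypothetical Python analog of the C++ struct."""
        def __init__(self):
            self.mixin_handle = -1
            self.recording = False
            self.copied = False

        def assign_from(self, other):   # mirrors operator=
            self.mixin_handle = other.mixin_handle
            self.recording = other.recording
            self.copied = True          # the destination is, by definition, a copy
            return self

    src = TraceInfoAnalog()
    src.mixin_handle = 7
    dst = TraceInfoAnalog().assign_from(src)
    assert dst.copied and not src.copied  # only the copy is marked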


