|
|
|
@@ -163,9 +163,9 @@ class trace: |
|
|
|
self._graph = None |
|
|
|
self._need_reset_nodes = None |
|
|
|
self._lazy_eval_graph = None |
|
|
|
self._lazy_eval_tensors = set() |
|
|
|
self._lazy_eval_tensors = {} |
|
|
|
self._lazy_eval_links = None |
|
|
|
self._active_tensors = set() |
|
|
|
self._active_tensors = {} |
|
|
|
self._tensor_remaps = None |
|
|
|
self._inputs_to_restore = None |
|
|
|
self._arg_bindings = None |
|
|
|
@@ -249,8 +249,8 @@ class trace: |
|
|
|
y._compiled_info = CompiledTensorProxy(h) |
|
|
|
y.mixin_handle = h |
|
|
|
outputs += [y] |
|
|
|
self._active_tensors[h] = TensorWeakRef(y) |
|
|
|
self._output_handles.update(ohandles) |
|
|
|
self._active_tensors.update([TensorWeakRef(o) for o in outputs]) |
|
|
|
return outputs |
|
|
|
|
|
|
|
def _apply_const(self, value, dtype, device): |
|
|
|
@@ -303,9 +303,11 @@ class trace: |
|
|
|
x.mixin_handle = h |
|
|
|
x.recording = True |
|
|
|
x._trace_mixin_info = info |
|
|
|
self._active_tensors[h] = TensorWeakRef(x) |
|
|
|
if self._symbolic: |
|
|
|
self._lazy_eval_tensors[h] = TensorWeakRef(x) |
|
|
|
|
|
|
|
self._seq.append((op, tuple(ihandles), tuple(ohandles))) |
|
|
|
self._active_tensors.update([TensorWeakRef(o) for o in outputs]) |
|
|
|
|
|
|
|
def _record_const(self, outputs): |
|
|
|
if skip_tracing: |
|
|
|
@@ -327,6 +329,8 @@ class trace: |
|
|
|
x.mixin_handle = h |
|
|
|
x.recording = True |
|
|
|
x._trace_mixin_info = info |
|
|
|
if self._symbolic: |
|
|
|
self._lazy_eval_tensors[h] = TensorWeakRef(x) |
|
|
|
self._seq.append(("Const", tuple(), tuple(ohandles))) |
|
|
|
|
|
|
|
def _set_active(self, active: bool): |
|
|
|
@@ -346,12 +350,12 @@ class trace: |
|
|
|
self._lazy_eval_links = () |
|
|
|
|
|
|
|
def _take_escaped_tensors(self): |
|
|
|
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors)) |
|
|
|
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values())) |
|
|
|
self._active_tensors.clear() |
|
|
|
return escaped_tensors |
|
|
|
|
|
|
|
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): |
|
|
|
lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors)) |
|
|
|
lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values())) |
|
|
|
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] |
|
|
|
self._apply_graph_options(lazy_eval_graph) |
|
|
|
# FIXME |
|
|
|
@@ -401,7 +405,7 @@ class trace: |
|
|
|
# eval lazy eval tensors |
|
|
|
self._lazy_eval( |
|
|
|
self._lazy_eval_graph, |
|
|
|
tuple(self._lazy_eval_tensors), |
|
|
|
self._lazy_eval_tensors, |
|
|
|
self._lazy_eval_links, |
|
|
|
) |
|
|
|
self._lazy_eval_graph = None |
|
|
|
@@ -433,9 +437,10 @@ class trace: |
|
|
|
if not self._untraced and self._pc != len(self._seq): |
|
|
|
raise TraceMismatchError("premature end") |
|
|
|
if not self._symbolic or not self._untraced: |
|
|
|
for x in self._active_tensors: |
|
|
|
for x in self._active_tensors.values(): |
|
|
|
if x() is not None: |
|
|
|
x()._dev_tensor() |
|
|
|
x()._reset_varnode() |
|
|
|
x().mixin_handle = -1 |
|
|
|
x().recording = False |
|
|
|
|
|
|
|
@@ -459,7 +464,7 @@ class trace: |
|
|
|
if self._untraced: |
|
|
|
# conditionally reading a compiled tensor in excluded region |
|
|
|
# is permitted, so we have to assume every tensor might be read |
|
|
|
for x in self._active_tensors: |
|
|
|
for x in self._active_tensors.values(): |
|
|
|
info = self._tinfo[x().mixin_handle] |
|
|
|
info.exported = True |
|
|
|
info.data_read = True |
|
|
|
@@ -626,8 +631,20 @@ class trace: |
|
|
|
if self._capture_as_const: |
|
|
|
self._process_inputs(*args, **kwargs) |
|
|
|
outputs = self.__wrapped__(*args, **kwargs) |
|
|
|
transform = False |
|
|
|
if outputs is not None: |
|
|
|
if not isinstance(outputs, collections.abc.Sequence): |
|
|
|
transform = True |
|
|
|
outputs = (outputs,) |
|
|
|
for o in outputs: |
|
|
|
if o._copied: |
|
|
|
self._active_tensors[o.mixin_handle] = TensorWeakRef(o) |
|
|
|
if self._untraced and self._symbolic: |
|
|
|
self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o) |
|
|
|
if self._capture_as_const: |
|
|
|
self._process_outputs(outputs) |
|
|
|
if transform: |
|
|
|
outputs = outputs[0] |
|
|
|
return outputs |
|
|
|
|
|
|
|
def dump( |
|
|
|
@@ -1031,7 +1048,6 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): |
|
|
|
if require_links: |
|
|
|
active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),) |
|
|
|
|
|
|
|
active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs]) |
|
|
|
return outputs |
|
|
|
|
|
|
|
|
|
|
|
@@ -1042,7 +1058,6 @@ def apply_const_symbolic_mode(value, dtype, device): |
|
|
|
ret = RawTensor(graph.make_const(value, dtype=dtype, device=device)) |
|
|
|
if np.array(value).ndim == 0: |
|
|
|
setscalar(ret) |
|
|
|
active_trace._lazy_eval_tensors.add(TensorWeakRef(ret)) |
|
|
|
return (ret,) |
|
|
|
|
|
|
|
|
|
|
|
|