| @@ -124,7 +124,8 @@ class trace: | |||||
| self._graph = None | self._graph = None | ||||
| self._need_reset_nodes = None | self._need_reset_nodes = None | ||||
| self._lazy_eval_graph = None | self._lazy_eval_graph = None | ||||
| self._lazy_eval_tensors = weakref.WeakSet() | |||||
| self._lazy_eval_tensors = [] | |||||
| self._lazy_eval_tensor_count = 0 | |||||
| self._active_tensors = weakref.WeakSet() | self._active_tensors = weakref.WeakSet() | ||||
| self._tensor_remaps = None | self._tensor_remaps = None | ||||
| self._inputs_to_restore = None | self._inputs_to_restore = None | ||||
| @@ -283,12 +284,18 @@ class trace: | |||||
| x._TraceMixin__restore() | x._TraceMixin__restore() | ||||
| if self._symbolic: | if self._symbolic: | ||||
| # eval lazy eval tensors | # eval lazy eval tensors | ||||
| lazy_eval_tensors = tuple(self._lazy_eval_tensors) | |||||
| if lazy_eval_tensors: | |||||
| readers = [ | |||||
| G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] | |||||
| for x in lazy_eval_tensors | |||||
| ] | |||||
| if self._lazy_eval_tensors: | |||||
| lazy_eval_tensors = [] | |||||
| visited = set() | |||||
| readers = [] | |||||
| for x in self._lazy_eval_tensors: | |||||
| x = x() | |||||
| if x is None or x in visited: | |||||
| continue | |||||
| reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] | |||||
| readers.append(reader) | |||||
| lazy_eval_tensors.append(x) | |||||
| visited.add(x) | |||||
| self._apply_graph_options(self._lazy_eval_graph) | self._apply_graph_options(self._lazy_eval_graph) | ||||
| self._lazy_eval_graph.compile(*readers) | self._lazy_eval_graph.compile(*readers) | ||||
| self._lazy_eval_graph() | self._lazy_eval_graph() | ||||
| @@ -844,7 +851,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||||
| ] | ] | ||||
| ovars = apply(op, *ivars) | ovars = apply(op, *ivars) | ||||
| outputs = [LazyEvalTensor(v) for v in ovars] | outputs = [LazyEvalTensor(v) for v in ovars] | ||||
| active_trace._lazy_eval_tensors.update(outputs) | |||||
| active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs) | |||||
| return outputs | return outputs | ||||
| @@ -855,7 +862,7 @@ apply.disable(apply_symbolic_mode) | |||||
| def apply_const_symbolic_mode(op: Const, *args: RawTensor): | def apply_const_symbolic_mode(op: Const, *args: RawTensor): | ||||
| graph = active_trace._lazy_eval_graph | graph = active_trace._lazy_eval_graph | ||||
| ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device)) | ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device)) | ||||
| active_trace._lazy_eval_tensors.add(ret) | |||||
| active_trace._lazy_eval_tensors.append(weakref.ref(ret)) | |||||
| return (ret,) | return (ret,) | ||||