@@ -191,19 +191,20 @@ class trace:
         if len(ihandles) != len(args):
             raise TraceMismatchError("op input size different from last time")
 
+        # check all inputs of current op
         for h, x in zip(ihandles, args):
             info = self._tinfo[h]
             if info.external:
                 if (
-                    x.__class__ is CompiledTensorProxy
-                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
+                    x._compiled_info is not None
+                    and not self._tinfo[x._mixin_handle].exported
                 ):
                     raise TraceMismatchError(
                         "failed to capture: input was an external tensor "
                         "last time, got an internal tensor this time"
                     )
                 if info.bound_data:
-                    if x.__class__ is CompiledTensorProxy:
+                    if x._compiled_info is not None:
                         raise TraceMismatchError(
                             "const capture violated: was an external tensor "
                             "last time, got an internal tensor this time"
@@ -225,17 +226,17 @@ class trace:
                         )
                     info.data_setter.set_value(x._dev_tensor())
             else:
-                if x.mixin_handle == -1:
+                if x._mixin_handle == -1:
                     if x._handle not in self._tensor_remaps:
                         raise TraceMismatchError(
                             "unexpected capture: trying to use an external tensor as "
                             "input, but that input was an internal tensor last time"
                         )
                     else:
-                        x.mixin_handle = self._tensor_remaps[
+                        x._mixin_handle = self._tensor_remaps[
                             x._handle
                         ]._CompiledTensorProxy__handle
-                if x.mixin_handle != h:
+                if x._mixin_handle != h:
                     raise TraceMismatchError(
                         "mis-wiring: input edge to a data flow "
                         "graph node is different from last time"
@@ -245,9 +246,10 @@ class trace:
         outputs = []
         for h in ohandles:
             info = self._tinfo[h]
+            # generate output tensor and create compiled info
             y = RawTensor(info.varnode)
             y._compiled_info = CompiledTensorProxy(h)
-            y.mixin_handle = h
+            y._mixin_handle = h
             outputs += [y]
             self._active_tensors[h] = TensorWeakRef(y)
         self._output_handles.update(ohandles)
@@ -260,6 +262,7 @@ class trace:
             raise TraceMismatchError("trace should end here, but more op observed")
         record = self._seq[self._pc]
         op_, ihandles, ohandles = record
+        # Const op is represented by a str
         assert isinstance(op_, str) and op_ == "Const"
 
         eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
@@ -273,17 +276,18 @@ class trace:
             outputs = [self._tinfo[h].bound_data]
         return outputs
 
+    # run in first step, record information for trace
     def _record_op(self, op, inputs, outputs):
         if skip_tracing:
             for x in inputs:
-                h = getattr(x, "mixin_handle", -1)
+                h = getattr(x, "_mixin_handle", -1)
                 if h >= 0:
                     self._tinfo[h].data_read = True
             return
 
         ihandles = []
         for x in inputs:
-            h = getattr(x, "mixin_handle", -1)
+            h = getattr(x, "_mixin_handle", -1)
             if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                 h, info = self._new_handle()
                 info.external = True
@@ -300,8 +304,8 @@ class trace:
             h, info = self._new_handle()
             ohandles.append(h)
             info.external = False
-            x.mixin_handle = h
-            x.recording = True
+            x._mixin_handle = h
+            x._recording = True
             x._trace_mixin_info = info
             self._active_tensors[h] = TensorWeakRef(x)
             if self._symbolic:
@@ -312,7 +316,7 @@ class trace:
     def _record_const(self, outputs):
         if skip_tracing:
             (x,) = outputs
-            h = getattr(x, "mixin_handle", -1)
+            h = getattr(x, "_mixin_handle", -1)
             if h >= 0:
                 self._tinfo[h].data_read = True
             return
@@ -326,8 +330,8 @@ class trace:
         info.shape = x.shape
         info.bound_data = x
         info.is_const = True
-        x.mixin_handle = h
-        x.recording = True
+        x._mixin_handle = h
+        x._recording = True
         x._trace_mixin_info = info
         if self._symbolic:
             self._lazy_eval_tensors[h] = TensorWeakRef(x)
@@ -371,6 +375,7 @@ class trace:
         lazy_eval_graph.compile(*lazy_eval_links, *readers)
         lazy_eval_graph()
         for r, x in zip(readers, lazy_eval_tensors):
+            # get values from lazy_eval_graph and assign to lazy_eval tensor
            x()._handle = RawTensor(r.op.get_value())._handle
            x()._reset_varnode()
 
@@ -395,14 +400,14 @@ class trace:
         if self._untraced:
             for x in escaped_tensors:
                 if x():
-                    info = self._tinfo[x().mixin_handle]
+                    info = self._tinfo[x()._mixin_handle]
                     info.data_read = True
-                    x().mixin_handle = -1
-                    x().recording = False
+                    x()._mixin_handle = -1
+                    x()._recording = False
             if self._inputs_to_restore:
                 for x in self._inputs_to_restore:
-                    x.mixin_handle = -1
-                    x.recording = False
+                    x._mixin_handle = -1
+                    x._recording = False
             if self._symbolic and (
                 self._lazy_eval_tensors or self._lazy_eval_links
             ):
@@ -441,12 +446,13 @@ class trace:
             if not self._untraced and self._pc != len(self._seq):
                 raise TraceMismatchError("premature end")
             if not self._symbolic or not self._untraced:
+                # reset output tensors
                 for x in self._active_tensors.values():
                     if x() is not None:
                         x()._dev_tensor()
                         x()._reset_varnode()
-                        x().mixin_handle = -1
-                        x().recording = False
+                        x()._mixin_handle = -1
+                        x()._recording = False
                         x()._trace_mixin_info = None
 
             try:
@@ -470,10 +476,14 @@ class trace:
                 # conditionally reading a compiled tensor in excluded region
                 # is permitted, so we have to assume every tensor might be read
                 for x in self._active_tensors.values():
-                    info = self._tinfo[x().mixin_handle]
-                    info.exported = True
-                    info.data_read = True
-                    x()._dev_tensor()
+                    if x():
+                        info = self._tinfo[x()._mixin_handle]
+                        info.exported = True
+                        info.data_read = True
             else:
                 for x in self._active_tensors.values():
-                    x()._dev_tensor()
+                    if x():
+                        x()._dev_tensor()
 
     def _apply_graph_options(self, graph):
@@ -528,7 +538,6 @@ class trace:
             info.varnode = opnode.outputs[0]
             in_out_links += opnode.outputs[1:]
 
-        cnt_data, cnt_value, cnt_shape = 0, 0, 0
         for op, ihandles, ohandles in self._seq:
             if isinstance(op, str) and op == "Const":
                 assert len(ihandles) == 0
@@ -604,16 +613,13 @@ class trace:
                 # Shape can be obtained from data so doesn't need its own
                 # output node. On the other hand, value is read separately
                 # to leverage eager h2d copy
-                cnt_data += 1
                 info.shape_read = False
                 opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                 add_reader(opnode)
             if info.value_read:
-                cnt_value += 1
                 opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                 add_reader(opnode)
             if info.shape_read:
-                cnt_shape += 1
                 opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                 add_reader(opnode)
 
@@ -637,15 +643,17 @@ class trace:
                 self._process_inputs(*args, **kwargs)
             outputs = self.__wrapped__(*args, **kwargs)
             transform = False
+            # outputs can be None
             if outputs is not None:
                 if not isinstance(outputs, collections.abc.Sequence):
                     transform = True
                     outputs = (outputs,)
                 for o in outputs:
+                    # if outputs are copied, then use the newest info in trace data structure
                     if o._copied:
-                        self._active_tensors[o.mixin_handle] = TensorWeakRef(o)
+                        self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
                         if self._untraced and self._symbolic:
-                            self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o)
+                            self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
             if self._capture_as_const:
                 self._process_outputs(outputs)
             if transform:
@@ -819,8 +827,8 @@ class trace:
             info.device = x.device
             info.dtype = x.dtype
             info.shape = x.numpy().shape
-            x.mixin_handle = h
-            x.recording = True
+            x._mixin_handle = h
+            x._recording = True
             x._trace_mixin_info = info
             self._inputs_to_restore.append(x)
             return h
@@ -914,12 +922,12 @@ class trace:
             if not isinstance(x, RawTensor):
                 raise TypeError("every item of return value should be tensor")
             if self._untraced:
-                h = x.mixin_handle
+                h = x._mixin_handle
                 if h < 0:
                     raise RuntimeError("output is not computed from inputs")
                 self._output_bindings.append(h)
             else:
-                h = x.mixin_handle
+                h = x._mixin_handle
                 if h not in self._output_handles:
                     raise RuntimeError("output is not computed from inputs")
                 if h != self._output_bindings[i]:
@@ -938,6 +946,11 @@ class trace:
             raise RuntimeError("trace is not set with profiling=True")
         return json.loads(self._profiler.get())
 
+    def __del__(self):
+        for x in self._tinfo:
+            if getattr(x, "bound_data", None):
+                x.bound_data = None
+
     def trace(self, *args, **kwargs):
         raise NotImplementedError(
             "trace is deemed unbeneficial with the new "