GitOrigin-RevId: 4edc38eaf2
tags/v1.2.0
| @@ -20,4 +20,4 @@ class Const: | |||||
| def __call__(self, *reference): | def __call__(self, *reference): | ||||
| Wrapper = type(reference[0]) | Wrapper = type(reference[0]) | ||||
| return (Wrapper(self.value, self.dtype, self.device),) | |||||
| return (Wrapper(self.value, self.dtype, self.device, True),) | |||||
| @@ -19,10 +19,11 @@ import numpy as np | |||||
| from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id | from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id | ||||
| from .. import _imperative_rt | from .. import _imperative_rt | ||||
| from .._imperative_rt import GraphOptimizeOptions | from .._imperative_rt import GraphOptimizeOptions | ||||
| from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | |||||
| from .._imperative_rt.ops import BackwardGraph | from .._imperative_rt.ops import BackwardGraph | ||||
| from .._wrap import device as as_device | from .._wrap import device as as_device | ||||
| from ..ops.builtin import OpDef | from ..ops.builtin import OpDef | ||||
| from .core import OpBase, TensorBase, apply | |||||
| from .core import OpBase, TensorBase | |||||
| class Graph(_imperative_rt.ComputingGraph): | class Graph(_imperative_rt.ComputingGraph): | ||||
| @@ -269,9 +270,8 @@ def optimize_for_inference(dest_vars, **kwargs): | |||||
| if kwargs: | if kwargs: | ||||
| raise ValueError("unknown options: %s" % list(kwargs)) | raise ValueError("unknown options: %s" % list(kwargs)) | ||||
| res_vars = _imperative_rt.optimize_for_inference( | |||||
| [i._node for i in dest_vars], inference_options | |||||
| ) | |||||
| dest_vars = [var._node for var in dest_vars] | |||||
| res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options) | |||||
| return [VarNode(i) for i in res_vars] | return [VarNode(i) for i in res_vars] | ||||
| @@ -437,19 +437,25 @@ def _unwrap(x): | |||||
| return x | return x | ||||
| @apply.register() | |||||
| def _(op: OpDef, *args: VarNode): | |||||
| def apply_normal_op(op: OpDef, *args: VarNode): | |||||
| outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | ||||
| return _wrap(outputs) | return _wrap(outputs) | ||||
| @apply.register() | |||||
| def _(op: BackwardGraph, *args: VarNode): | |||||
| def apply_backward_varnode(op: BackwardGraph, *args: VarNode): | |||||
| assert args | assert args | ||||
| graph = args[0].graph | graph = args[0].graph | ||||
| return BackwardGraph.interpret( | |||||
| op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args | |||||
| outputs = op.interpret( | |||||
| op, | |||||
| lambda op, args: apply_normal_op(op, *args), | |||||
| graph._make_const_for_backward, | |||||
| args, | |||||
| ) | ) | ||||
| outputs = [o._node if hasattr(o, "_node") else o for o in outputs] | |||||
| return outputs | |||||
| set_cpp_apply_backward_varnode(apply_backward_varnode) | |||||
| def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | ||||
| @@ -6,5 +6,23 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from ..core._imperative_rt.core2 import ( | |||||
| set_cpp_apply_compiled_mode, | |||||
| set_cpp_apply_const_compiled_mode, | |||||
| set_cpp_apply_const_with_tracing, | |||||
| set_cpp_apply_with_tracing, | |||||
| ) | |||||
| from .sublinear_memory_config import SublinearMemoryConfig | from .sublinear_memory_config import SublinearMemoryConfig | ||||
| from .tracing import exclude_from_trace, trace | |||||
| from .tracing import ( | |||||
| apply_compiled_mode, | |||||
| apply_const_compiled_mode, | |||||
| apply_const_with_tracing, | |||||
| apply_with_tracing, | |||||
| exclude_from_trace, | |||||
| trace, | |||||
| ) | |||||
| set_cpp_apply_with_tracing(apply_with_tracing) | |||||
| set_cpp_apply_const_with_tracing(apply_const_with_tracing) | |||||
| set_cpp_apply_compiled_mode(apply_compiled_mode) | |||||
| set_cpp_apply_const_compiled_mode(apply_const_compiled_mode) | |||||
| @@ -18,8 +18,20 @@ import weakref | |||||
| import numpy as np | import numpy as np | ||||
| from ..core._imperative_rt import GraphProfiler | |||||
| from ..core._imperative_rt.core2 import Tensor | |||||
| from ..core._imperative_rt import GraphProfiler, common, put | |||||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | |||||
| from ..core._imperative_rt.core2 import ( | |||||
| TensorWeakRef, | |||||
| apply, | |||||
| call_level, | |||||
| set_compiled, | |||||
| set_symbolic, | |||||
| set_tracing, | |||||
| skip_tracing, | |||||
| unset_compiled, | |||||
| unset_symbolic, | |||||
| unset_tracing, | |||||
| ) | |||||
| from ..core._imperative_rt.ops import ( | from ..core._imperative_rt.ops import ( | ||||
| CollectiveComm, | CollectiveComm, | ||||
| GaussianRNG, | GaussianRNG, | ||||
| @@ -29,10 +41,9 @@ from ..core._imperative_rt.ops import ( | |||||
| ) | ) | ||||
| from ..core._trace_option import set_symbolic_shape | from ..core._trace_option import set_symbolic_shape | ||||
| from ..core._wrap import device as as_device | from ..core._wrap import device as as_device | ||||
| from ..core.ops.builtin import OpDef | |||||
| from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
| from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
| from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||||
| from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor | |||||
| from .sublinear_memory_config import SublinearMemoryConfig | from .sublinear_memory_config import SublinearMemoryConfig | ||||
| @@ -45,7 +56,6 @@ class TraceMismatchError(RuntimeError): | |||||
| active_trace = None | active_trace = None | ||||
| skip_tracing = False | |||||
| def is_tracing(): | def is_tracing(): | ||||
| @@ -63,11 +73,13 @@ def exclude_from_trace(): | |||||
| return | return | ||||
| try: | try: | ||||
| skip_tracing = True | skip_tracing = True | ||||
| unset_tracing() | |||||
| if active_trace is not None: | if active_trace is not None: | ||||
| active_trace._begin_excluded_region() | active_trace._begin_excluded_region() | ||||
| yield | yield | ||||
| finally: | finally: | ||||
| skip_tracing = False | skip_tracing = False | ||||
| set_tracing() | |||||
| class TensorInfo: | class TensorInfo: | ||||
| @@ -75,9 +87,6 @@ class TensorInfo: | |||||
| # collected attributes | # collected attributes | ||||
| "external", | "external", | ||||
| "exported", | "exported", | ||||
| "data_read", | |||||
| "shape_read", | |||||
| "value_read", | |||||
| "device", | "device", | ||||
| "dtype", | "dtype", | ||||
| "shape", | "shape", | ||||
| @@ -93,9 +102,6 @@ class TensorInfo: | |||||
| def __init__(self): | def __init__(self): | ||||
| self.exported = None | self.exported = None | ||||
| self.data_read = None | |||||
| self.shape_read = None | |||||
| self.value_read = None | |||||
| self.bound_data = None | self.bound_data = None | ||||
| self.data_setter = None | self.data_setter = None | ||||
| @@ -147,6 +153,8 @@ class trace: | |||||
| self._profiler = None | self._profiler = None | ||||
| self._graph_opt_level = opt_level | self._graph_opt_level = opt_level | ||||
| self._symbolic_shape = symbolic_shape | self._symbolic_shape = symbolic_shape | ||||
| self._handle2tensors = {} | |||||
| self._handle2compiledtensors = {} | |||||
| self._reset() | self._reset() | ||||
| @@ -158,9 +166,9 @@ class trace: | |||||
| self._graph = None | self._graph = None | ||||
| self._need_reset_nodes = None | self._need_reset_nodes = None | ||||
| self._lazy_eval_graph = None | self._lazy_eval_graph = None | ||||
| self._lazy_eval_tensors = weakref.WeakSet() | |||||
| self._lazy_eval_tensors = set() | |||||
| self._lazy_eval_links = None | self._lazy_eval_links = None | ||||
| self._active_tensors = weakref.WeakSet() | |||||
| self._active_tensors = set() | |||||
| self._tensor_remaps = None | self._tensor_remaps = None | ||||
| self._inputs_to_restore = None | self._inputs_to_restore = None | ||||
| self._arg_bindings = None | self._arg_bindings = None | ||||
| @@ -220,66 +228,72 @@ class trace: | |||||
| ) | ) | ||||
| info.data_setter.set_value(x._dev_tensor()) | info.data_setter.set_value(x._dev_tensor()) | ||||
| else: | else: | ||||
| if x.__class__ is not CompiledTensorProxy: | |||||
| if x not in self._tensor_remaps: | |||||
| raise TraceMismatchError( | |||||
| "unexpected capture: trying to use an external tensor as " | |||||
| "input, but that input was an internal tensor last time" | |||||
| ) | |||||
| else: | |||||
| x = self._tensor_remaps[x] | |||||
| if x._CompiledTensorProxy__handle != h: | |||||
| raise TraceMismatchError( | |||||
| "mis-wiring: input edge to an data flow " | |||||
| "graph node is different from last time" | |||||
| ) | |||||
| pass | |||||
| # if x.__class__ is not CompiledTensorProxy: | |||||
| # if x not in self._tensor_remaps: | |||||
| # raise TraceMismatchError( | |||||
| # "unexpected capture: trying to use an external tensor as " | |||||
| # "input, but that input was an internal tensor last time" | |||||
| # ) | |||||
| # else: | |||||
| # x = self._tensor_remaps[x] | |||||
| # if x._CompiledTensorProxy__handle != h: | |||||
| # raise TraceMismatchError( | |||||
| # "mis-wiring: input edge to an data flow " | |||||
| # "graph node is different from last time" | |||||
| # ) | |||||
| self._pc += 1 | self._pc += 1 | ||||
| outputs = tuple([CompiledTensorProxy(h) for h in ohandles]) | |||||
| self._active_tensors.update(outputs) | |||||
| for h in ohandles: | |||||
| t = CompiledTensorProxy(h) | |||||
| t._dev_tensor() | |||||
| self._handle2compiledtensors[h] = t | |||||
| outputs = [self._handle2tensors[h] for h in ohandles] | |||||
| self._active_tensors.update([TensorWeakRef(o) for o in outputs]) | |||||
| return outputs | return outputs | ||||
| def _apply_const(self, op, args): | |||||
| def _apply_const(self, value, dtype, device): | |||||
| assert not self._untraced | assert not self._untraced | ||||
| # check against trace | # check against trace | ||||
| if self._pc >= len(self._seq): | if self._pc >= len(self._seq): | ||||
| raise TraceMismatchError("trace should end here, but more op observed") | raise TraceMismatchError("trace should end here, but more op observed") | ||||
| record = self._seq[self._pc] | record = self._seq[self._pc] | ||||
| op_, ihandles, ohandles = record | op_, ihandles, ohandles = record | ||||
| assert isinstance(op_, Const) | |||||
| eq = op_.value == op.value | |||||
| if not isinstance(eq, bool): | |||||
| eq = all(eq) | |||||
| if not eq: | |||||
| raise TraceMismatchError( | |||||
| "const tensor violated: got a different tensor this time" | |||||
| ) | |||||
| assert isinstance(op_, str) and op_ == "Const" | |||||
| # TODO : assert on const value | |||||
| # eq = value == self._tinfo[ohandles[0]].bound_data.numpy() | |||||
| # if not isinstance(eq, bool): | |||||
| # eq = all(eq) | |||||
| # if not eq: | |||||
| # raise TraceMismatchError( | |||||
| # "const tensor violated: got a different tensor this time" | |||||
| # ) | |||||
| self._pc += 1 | self._pc += 1 | ||||
| (h,) = ohandles | (h,) = ohandles | ||||
| outputs = tuple([self._tinfo[h].bound_data]) | |||||
| outputs = [self._tinfo[h].bound_data] | |||||
| return outputs | return outputs | ||||
| def _record_op(self, op, inputs, outputs): | def _record_op(self, op, inputs, outputs): | ||||
| if skip_tracing: | if skip_tracing: | ||||
| for x in inputs: | for x in inputs: | ||||
| h = getattr(x, "_TraceMixin__handle", None) | |||||
| if h is not None: | |||||
| self._tinfo[h].data_read = True | |||||
| h = getattr(x, "mixin_handle", -1) | |||||
| if h >= 0: | |||||
| x.data_read = True | |||||
| return | return | ||||
| ihandles = [] | ihandles = [] | ||||
| for x in inputs: | for x in inputs: | ||||
| h = getattr(x, "_TraceMixin__handle", None) | |||||
| if h is None or (not self._capture_as_const and self._tinfo[h].exported): | |||||
| h = getattr(x, "mixin_handle", -1) | |||||
| if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): | |||||
| h, info = self._new_handle() | h, info = self._new_handle() | ||||
| info.external = True | info.external = True | ||||
| info.device = x.device | info.device = x.device | ||||
| info.dtype = x.dtype | info.dtype = x.dtype | ||||
| info.shape = x.shape | info.shape = x.shape | ||||
| if self._capture_as_const: | if self._capture_as_const: | ||||
| info.bound_data = x | |||||
| info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False) | |||||
| ihandles.append(h) | ihandles.append(h) | ||||
| @@ -288,17 +302,18 @@ class trace: | |||||
| h, info = self._new_handle() | h, info = self._new_handle() | ||||
| ohandles.append(h) | ohandles.append(h) | ||||
| info.external = False | info.external = False | ||||
| TraceMixin._TraceMixin__inject(x, h) | |||||
| x.mixin_handle = h | |||||
| self._handle2tensors[h] = x | |||||
| self._seq.append((op, tuple(ihandles), tuple(ohandles))) | self._seq.append((op, tuple(ihandles), tuple(ohandles))) | ||||
| self._active_tensors.update(outputs) | |||||
| self._active_tensors.update([TensorWeakRef(o) for o in outputs]) | |||||
| def _record_const(self, op, outputs): | |||||
| def _record_const(self, outputs): | |||||
| if skip_tracing: | if skip_tracing: | ||||
| (x,) = outputs | (x,) = outputs | ||||
| h = getattr(x, "_TraceMixin__handle", None) | |||||
| if h is not None: | |||||
| self._tinfo[h].data_read = True | |||||
| h = getattr(x, "mixin_handle", -1) | |||||
| if h >= 0: | |||||
| x.data_read = True | |||||
| return | return | ||||
| (x,) = outputs | (x,) = outputs | ||||
| @@ -310,8 +325,9 @@ class trace: | |||||
| info.shape = x.shape | info.shape = x.shape | ||||
| info.bound_data = x | info.bound_data = x | ||||
| info.is_const = True | info.is_const = True | ||||
| TraceMixin._TraceMixin__inject(x, h) | |||||
| self._seq.append((op, tuple(), tuple(ohandles))) | |||||
| x.mixin_handle = h | |||||
| self._handle2tensors[h] = x | |||||
| self._seq.append(("Const", tuple(), tuple(ohandles))) | |||||
| def _set_active(self, active: bool): | def _set_active(self, active: bool): | ||||
| global active_trace | global active_trace | ||||
| @@ -324,11 +340,8 @@ class trace: | |||||
| active_trace = None | active_trace = None | ||||
| def _init_trace(self, symbolic: bool): | def _init_trace(self, symbolic: bool): | ||||
| apply.enable(apply_with_tracing) | |||||
| apply.enable(apply_const_with_tracing) | |||||
| if symbolic: | if symbolic: | ||||
| apply.enable(apply_symbolic_mode) | |||||
| apply.enable(apply_const_symbolic_mode) | |||||
| set_symbolic() | |||||
| self._lazy_eval_graph = G.Graph() | self._lazy_eval_graph = G.Graph() | ||||
| self._apply_graph_options(self._lazy_eval_graph) | self._apply_graph_options(self._lazy_eval_graph) | ||||
| self._lazy_eval_links = () | self._lazy_eval_links = () | ||||
| @@ -339,10 +352,7 @@ class trace: | |||||
| return escaped_tensors | return escaped_tensors | ||||
| def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | ||||
| readers = [ | |||||
| G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] | |||||
| for x in lazy_eval_tensors | |||||
| ] | |||||
| readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] | |||||
| self._apply_graph_options(lazy_eval_graph) | self._apply_graph_options(lazy_eval_graph) | ||||
| # FIXME | # FIXME | ||||
| if self._graph_opt_level is not None: | if self._graph_opt_level is not None: | ||||
| @@ -353,20 +363,22 @@ class trace: | |||||
| lazy_eval_graph.compile(*lazy_eval_links, *readers) | lazy_eval_graph.compile(*lazy_eval_links, *readers) | ||||
| lazy_eval_graph() | lazy_eval_graph() | ||||
| for r, x in zip(readers, lazy_eval_tensors): | for r, x in zip(readers, lazy_eval_tensors): | ||||
| assign_raw_tensor(x, as_raw_tensor(r.op.get_value())) | |||||
| x()._handle = RawTensor(r.op.get_value())._handle | |||||
| @contextlib.contextmanager | @contextlib.contextmanager | ||||
| def _setup(self): | def _setup(self): | ||||
| interrupted = False | interrupted = False | ||||
| def do_enter(): | def do_enter(): | ||||
| set_tracing() | |||||
| self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape) | self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape) | ||||
| self._set_active(True) | self._set_active(True) | ||||
| if self._untraced: | if self._untraced: | ||||
| self._init_trace(self._symbolic) | self._init_trace(self._symbolic) | ||||
| else: | else: | ||||
| apply.enable(apply_compiled_mode) | |||||
| apply.enable(apply_const_compiled_mode) | |||||
| # disable symbolic mode | |||||
| unset_symbolic() | |||||
| set_compiled() | |||||
| if self._graph is None: | if self._graph is None: | ||||
| self._compile() | self._compile() | ||||
| self._graph.execute() | self._graph.execute() | ||||
| @@ -375,12 +387,12 @@ class trace: | |||||
| escaped_tensors = self._take_escaped_tensors() | escaped_tensors = self._take_escaped_tensors() | ||||
| if self._untraced: | if self._untraced: | ||||
| for x in escaped_tensors: | for x in escaped_tensors: | ||||
| info = self._tinfo[x._TraceMixin__handle] | |||||
| info.data_read = True | |||||
| x._TraceMixin__restore() | |||||
| info = self._tinfo[x().mixin_handle] | |||||
| x().data_read = True | |||||
| x().mixin_handle = -1 | |||||
| if self._inputs_to_restore: | if self._inputs_to_restore: | ||||
| for x in self._inputs_to_restore: | for x in self._inputs_to_restore: | ||||
| x._TraceMixin__restore() | |||||
| x.mixin_handle = -1 | |||||
| if self._symbolic and ( | if self._symbolic and ( | ||||
| self._lazy_eval_tensors or self._lazy_eval_links | self._lazy_eval_tensors or self._lazy_eval_links | ||||
| ): | ): | ||||
| @@ -399,7 +411,7 @@ class trace: | |||||
| if self._pc == len(self._seq): | if self._pc == len(self._seq): | ||||
| for x in escaped_tensors: | for x in escaped_tensors: | ||||
| try: | try: | ||||
| assign_raw_tensor(x, as_raw_tensor(x._dev_tensor())) | |||||
| assign_raw_tensor(x(), RawTensor(x()._dev_tensor())) | |||||
| except TraceMismatchError: | except TraceMismatchError: | ||||
| # TraceMismatchError thrown in do_exit | # TraceMismatchError thrown in do_exit | ||||
| pass | pass | ||||
| @@ -409,22 +421,20 @@ class trace: | |||||
| # reset status | # reset status | ||||
| self._pc = 0 | self._pc = 0 | ||||
| self._tensor_remaps = None | self._tensor_remaps = None | ||||
| apply.disable(apply_with_tracing) | |||||
| apply.disable(apply_const_with_tracing) | |||||
| apply.disable(apply_symbolic_mode) | |||||
| apply.disable(apply_const_symbolic_mode) | |||||
| apply.disable(apply_compiled_mode) | |||||
| apply.disable(apply_const_compiled_mode) | |||||
| self._set_active(False) | self._set_active(False) | ||||
| # Restore global variable | |||||
| set_symbolic_shape(self._save_symbolic_shape) | set_symbolic_shape(self._save_symbolic_shape) | ||||
| unset_compiled() | |||||
| unset_symbolic() | |||||
| unset_tracing() | |||||
| def do_exit(): | def do_exit(): | ||||
| unset_tracing() | |||||
| if not self._untraced and self._pc != len(self._seq): | if not self._untraced and self._pc != len(self._seq): | ||||
| raise TraceMismatchError("premature end") | raise TraceMismatchError("premature end") | ||||
| if not self._symbolic or not self._untraced: | if not self._symbolic or not self._untraced: | ||||
| for x in self._active_tensors: | for x in self._active_tensors: | ||||
| x._dev_tensor() | |||||
| x()._dev_tensor() | |||||
| x().mixin_handle = -1 | |||||
| try: | try: | ||||
| do_enter() | do_enter() | ||||
| @@ -447,9 +457,9 @@ class trace: | |||||
| # conditionally reading a compiled tensor in excluded region | # conditionally reading a compiled tensor in excluded region | ||||
| # is permitted, so we have to assume every tensor might be read | # is permitted, so we have to assume every tensor might be read | ||||
| for x in self._active_tensors: | for x in self._active_tensors: | ||||
| info = self._tinfo[x._TraceMixin__handle] | |||||
| info = self._tinfo[x().mixin_handle] | |||||
| info.exported = True | info.exported = True | ||||
| info.data_read = True | |||||
| x().data_read = True | |||||
| def _apply_graph_options(self, graph): | def _apply_graph_options(self, graph): | ||||
| @@ -503,7 +513,7 @@ class trace: | |||||
| in_out_links += opnode.outputs[1:] | in_out_links += opnode.outputs[1:] | ||||
| for op, ihandles, ohandles in self._seq: | for op, ihandles, ohandles in self._seq: | ||||
| if isinstance(op, Const): | |||||
| if isinstance(op, str) and op == "Const": | |||||
| assert len(ihandles) == 0 | assert len(ihandles) == 0 | ||||
| (h,) = ohandles | (h,) = ohandles | ||||
| info = self._tinfo[h] | info = self._tinfo[h] | ||||
| @@ -554,7 +564,10 @@ class trace: | |||||
| io_links = (info.varnode,) | io_links = (info.varnode,) | ||||
| ivars.append(info.varnode) | ivars.append(info.varnode) | ||||
| ivars = [RawTensor(ivar) for ivar in ivars] | |||||
| ovars = apply(op, *ivars) | ovars = apply(op, *ivars) | ||||
| ovars = [x._varnode for x in ovars] | |||||
| if require_links and len(ovars) > 0: | if require_links and len(ovars) > 0: | ||||
| io_links = (ovars[0],) | io_links = (ovars[0],) | ||||
| assert len(ovars) == len(ohandles) | assert len(ovars) == len(ohandles) | ||||
| @@ -568,7 +581,8 @@ class trace: | |||||
| readers.append(opnode.outputs[0]) | readers.append(opnode.outputs[0]) | ||||
| in_out_links = opnode.outputs | in_out_links = opnode.outputs | ||||
| if info.data_read: | |||||
| x = self._handle2tensors[h] | |||||
| if x.data_read: | |||||
| # Shape can be obtained from data so doesn't need its own | # Shape can be obtained from data so doesn't need its own | ||||
| # output node. On the other hand, value is read separately | # output node. On the other hand, value is read separately | ||||
| # to leverage eager h2d copy | # to leverage eager h2d copy | ||||
| @@ -581,6 +595,7 @@ class trace: | |||||
| if info.shape_read: | if info.shape_read: | ||||
| opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links) | opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links) | ||||
| add_reader(opnode) | add_reader(opnode) | ||||
| # FIXME | # FIXME | ||||
| if self._graph_opt_level is not None: | if self._graph_opt_level is not None: | ||||
| graph.options.graph_opt_level = self._graph_opt_level | graph.options.graph_opt_level = self._graph_opt_level | ||||
| @@ -593,18 +608,6 @@ class trace: | |||||
| for opnode in self._need_reset_nodes: | for opnode in self._need_reset_nodes: | ||||
| opnode.reset() | opnode.reset() | ||||
| def _require_shape(self, handle): | |||||
| info = self._tinfo[handle] | |||||
| info.shape_read = True | |||||
| def _require_value(self, handle): | |||||
| info = self._tinfo[handle] | |||||
| info.value_read = True | |||||
| def _require_data(self, handle): | |||||
| info = self._tinfo[handle] | |||||
| info.data_read = True | |||||
| def __call__(self, *args, **kwargs): | def __call__(self, *args, **kwargs): | ||||
| if is_tracing(): | if is_tracing(): | ||||
| return self.__wrapped__(*args, **kwargs) | return self.__wrapped__(*args, **kwargs) | ||||
| @@ -728,8 +731,9 @@ class trace: | |||||
| dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k | dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k | ||||
| ) | ) | ||||
| set_tracing() | |||||
| for op, ihandles, ohandles in self._seq: | for op, ihandles, ohandles in self._seq: | ||||
| if isinstance(op, Const): | |||||
| if isinstance(op, str) and op == "Const": | |||||
| assert len(ihandles) == 0 | assert len(ihandles) == 0 | ||||
| (h,) = ohandles | (h,) = ohandles | ||||
| info = self._tinfo[h] | info = self._tinfo[h] | ||||
| @@ -750,7 +754,9 @@ class trace: | |||||
| info.bound_data.numpy(), dtype=info.dtype, device=dumped_device | info.bound_data.numpy(), dtype=info.dtype, device=dumped_device | ||||
| ) | ) | ||||
| ivars.append(h2v[h]) | ivars.append(h2v[h]) | ||||
| ivars = [RawTensor(ivar) for ivar in ivars] | |||||
| ovars = apply(op, *ivars) | ovars = apply(op, *ivars) | ||||
| ovars = [x._varnode for x in ovars] | |||||
| assert len(ovars) == len(ohandles) | assert len(ovars) == len(ohandles) | ||||
| h2v.update(zip(ohandles, ovars)) | h2v.update(zip(ohandles, ovars)) | ||||
| @@ -761,6 +767,7 @@ class trace: | |||||
| v.name = output_names[i] | v.name = output_names[i] | ||||
| dest_vars.append(v) | dest_vars.append(v) | ||||
| dest_vars = [G.VarNode(var) for var in dest_vars] | |||||
| if optimize_for_inference: | if optimize_for_inference: | ||||
| dest_vars = G.optimize_for_inference(dest_vars, **kwargs) | dest_vars = G.optimize_for_inference(dest_vars, **kwargs) | ||||
| @@ -782,15 +789,15 @@ class trace: | |||||
| info.external = False | info.external = False | ||||
| info.device = x.device | info.device = x.device | ||||
| info.dtype = x.dtype | info.dtype = x.dtype | ||||
| info.shape = x.shape | |||||
| TraceMixin._TraceMixin__inject(x, h) | |||||
| info.shape = x.numpy().shape | |||||
| x.mixin_handle = h | |||||
| self._handle2tensors[h] = x | |||||
| self._inputs_to_restore.append(x) | self._inputs_to_restore.append(x) | ||||
| return h | return h | ||||
| self._arg_bindings = [] | self._arg_bindings = [] | ||||
| for i, x in enumerate(args): | for i, x in enumerate(args): | ||||
| x = find_raw_tensor(x) | |||||
| if x is None: | |||||
| if not isinstance(x, RawTensor): | |||||
| raise TypeError( | raise TypeError( | ||||
| "positional arguments should all be tensor " | "positional arguments should all be tensor " | ||||
| "but args[%d] cannot be recognized as one" % i | "but args[%d] cannot be recognized as one" % i | ||||
| @@ -799,8 +806,7 @@ class trace: | |||||
| self._kwarg_bindings = {} | self._kwarg_bindings = {} | ||||
| for k, x in kwargs.items(): | for k, x in kwargs.items(): | ||||
| x = find_raw_tensor(x) | |||||
| if x is not None: | |||||
| if isinstance(x, RawTensor): | |||||
| self._kwarg_bindings[k] = record_input(x) | self._kwarg_bindings[k] = record_input(x) | ||||
| else: | else: | ||||
| if len(args) != len(self._arg_bindings): | if len(args) != len(self._arg_bindings): | ||||
| @@ -809,8 +815,7 @@ class trace: | |||||
| self._tensor_remaps = {} | self._tensor_remaps = {} | ||||
| for i, (h, x) in enumerate(zip(self._arg_bindings, args)): | for i, (h, x) in enumerate(zip(self._arg_bindings, args)): | ||||
| x = find_raw_tensor(x) | |||||
| if x is None: | |||||
| if not isinstance(x, RawTensor): | |||||
| raise TypeError( | raise TypeError( | ||||
| "positional arguments should all be tensor " | "positional arguments should all be tensor " | ||||
| "but args[%d] cannot be recognized as one" % i | "but args[%d] cannot be recognized as one" % i | ||||
| @@ -825,8 +830,7 @@ class trace: | |||||
| kwargs_tensors = {} | kwargs_tensors = {} | ||||
| for k, x in kwargs.items(): | for k, x in kwargs.items(): | ||||
| x = find_raw_tensor(x) | |||||
| if x is not None: | |||||
| if isinstance(x, RawTensor): | |||||
| kwargs_tensors[k] = x | kwargs_tensors[k] = x | ||||
| if set(kwargs_tensors) != set(self._kwarg_bindings): | if set(kwargs_tensors) != set(self._kwarg_bindings): | ||||
| too_many = set(kwargs_tensors) - set(self._kwarg_bindings) | too_many = set(kwargs_tensors) - set(self._kwarg_bindings) | ||||
| @@ -877,18 +881,17 @@ class trace: | |||||
| self._output_bindings = [] | self._output_bindings = [] | ||||
| for i, x in enumerate(outputs): | for i, x in enumerate(outputs): | ||||
| x = find_raw_tensor(x) | |||||
| if x is None: | |||||
| if not isinstance(x, RawTensor): | |||||
| raise TypeError("every item of return value should be tensor") | raise TypeError("every item of return value should be tensor") | ||||
| if self._untraced: | if self._untraced: | ||||
| if not isinstance(x, TraceMixin): | |||||
| h = x.mixin_handle | |||||
| if h < 0: | |||||
| raise RuntimeError("output is not computed from inputs") | raise RuntimeError("output is not computed from inputs") | ||||
| h = x._TraceMixin__handle | |||||
| self._output_bindings.append(h) | self._output_bindings.append(h) | ||||
| else: | else: | ||||
| if not isinstance(x, CompiledTensorProxy): | |||||
| h = x.mixin_handle | |||||
| if h not in self._handle2compiledtensors: | |||||
| raise RuntimeError("output is not computed from inputs") | raise RuntimeError("output is not computed from inputs") | ||||
| h = x._CompiledTensorProxy__handle | |||||
| if h != self._output_bindings[i]: | if h != self._output_bindings[i]: | ||||
| raise TraceMismatchError( | raise TraceMismatchError( | ||||
| "retval[%s] is a different tensor than last time" | "retval[%s] is a different tensor than last time" | ||||
| @@ -912,7 +915,7 @@ class trace: | |||||
| ) | ) | ||||
| class CompiledTensorProxy(RawTensor): | |||||
| class CompiledTensorProxy: | |||||
| """ | """ | ||||
| Duck-typed RawTensor | Duck-typed RawTensor | ||||
| """ | """ | ||||
| @@ -924,6 +927,8 @@ class CompiledTensorProxy(RawTensor): | |||||
| self.__shape = None | self.__shape = None | ||||
| self.__data = None | self.__data = None | ||||
| self.__value = None | self.__value = None | ||||
| self.__tensor = active_trace._handle2tensors[handle] | |||||
| self.__tensor.mixin_handle = handle | |||||
| @property | @property | ||||
| def dtype(self): | def dtype(self): | ||||
| @@ -938,19 +943,19 @@ class CompiledTensorProxy(RawTensor): | |||||
| if self._isscalar: | if self._isscalar: | ||||
| return () | return () | ||||
| if self.__shape is None: | if self.__shape is None: | ||||
| if self.__info.shape_read: | |||||
| if self.__tensor.shape_read: | |||||
| self.__shape = self.__info.shape_reader.get_value().shape | self.__shape = self.__info.shape_reader.get_value().shape | ||||
| elif self.__info.data_read: | |||||
| self.__shape = self._dev_tensor().shape | |||||
| elif self.__tensor.data_read: | |||||
| self.__shape = self.__tensor._dev_tensor().shape | |||||
| else: | else: | ||||
| raise TraceMismatchError("shape of this tensor is not read in trace") | raise TraceMismatchError("shape of this tensor is not read in trace") | ||||
| return self.__shape | return self.__shape | ||||
| def numpy(self): | def numpy(self): | ||||
| if self.__value is None: | if self.__value is None: | ||||
| if self.__info.value_read: | |||||
| if self.__tensor.value_read: | |||||
| self.__value = self.__info.value_reader.get_value() | self.__value = self.__info.value_reader.get_value() | ||||
| elif self.__info.data_read: | |||||
| elif self.__tensor.data_read: | |||||
| self.__value = self._dev_tensor().numpy() | self.__value = self._dev_tensor().numpy() | ||||
| else: | else: | ||||
| raise TraceMismatchError("value of this tensor is not read in trace") | raise TraceMismatchError("value of this tensor is not read in trace") | ||||
| @@ -960,9 +965,11 @@ class CompiledTensorProxy(RawTensor): | |||||
| def _dev_tensor(self): | def _dev_tensor(self): | ||||
| if self.__data is None: | if self.__data is None: | ||||
| if not self.__info.data_read: | |||||
| if not self.__tensor.data_read: | |||||
| raise TraceMismatchError("raw data of this tensor is not read in trace") | raise TraceMismatchError("raw data of this tensor is not read in trace") | ||||
| self.__data = self.__info.data_reader.get_value() | self.__data = self.__info.data_reader.get_value() | ||||
| self.__tensor._reset(RawTensor(self.__data)) | |||||
| self.__tensor.mixin_handle = self.__handle | |||||
| return self.__data | return self.__data | ||||
| def _drop(self): | def _drop(self): | ||||
| @@ -975,132 +982,31 @@ class CompiledTensorProxy(RawTensor): | |||||
| return | return | ||||
| def __del__(self): | def __del__(self): | ||||
| if self.__info.shape_read and self.__shape is not None: | |||||
| if self.__tensor.shape_read and self.__shape is not None: | |||||
| self.__info.shape_reader.drop_value() | self.__info.shape_reader.drop_value() | ||||
| if self.__info.value_read and self.__value is not None: | |||||
| self.__info.value_reader.drop_value() | |||||
| if self.__info.data_read and self.__data is not None: | |||||
| # if self.__tensor.value_read and self.__value is not None: | |||||
| # self.__info.value_reader.drop_value() | |||||
| if self.__tensor.data_read and self.__data is not None: | |||||
| self.__info.data_reader.drop_value() | self.__info.data_reader.drop_value() | ||||
| class LazyEvalTensor(RawTensor): | |||||
| def __init__(self, varnode, isscalar=False): | |||||
| super().__init__() | |||||
| self.__varnode = varnode | |||||
| self._isscalar = isscalar | |||||
| @property | |||||
| def dtype(self): | |||||
| return self.__varnode.dtype | |||||
| @property | |||||
| def device(self): | |||||
| return self.__varnode.device | |||||
| @property | |||||
| def shape(self): | |||||
| if self._isscalar: | |||||
| return () | |||||
| return self.__varnode.shape | |||||
| def numpy(self): | |||||
| ret = self.__varnode.value | |||||
| if self._isscalar: | |||||
| ret = ret.squeeze() | |||||
| return ret | |||||
| def _drop(self): | |||||
| return | |||||
| def _swap_in(self): | |||||
| return | |||||
| def _swap_out(self): | |||||
| return | |||||
| def _dev_tensor(self): | |||||
| raise RuntimeError("cannot access data during symbolic tracing") | |||||
| class TraceMixin: | |||||
| __subclass_cache = {} | |||||
| def __inject(self, handle): | |||||
| cache = __class__.__subclass_cache | |||||
| cls = self.__class__ | |||||
| subcls = cache.get(cls) | |||||
| if subcls is None: | |||||
| subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {}) | |||||
| self.__class__ = subcls | |||||
| self.__handle = handle | |||||
| self.__cls = cls | |||||
| return self | |||||
| def __restore(self): | |||||
| cls = self.__cls | |||||
| del self.__handle | |||||
| del self.__cls | |||||
| self.__class__ = cls | |||||
| return self | |||||
| @property | |||||
| def shape(self): | |||||
| if not skip_tracing: | |||||
| active_trace._require_shape(self.__handle) | |||||
| return super().shape | |||||
| def numpy(self): | |||||
| if not skip_tracing: | |||||
| active_trace._require_value(self.__handle) | |||||
| return super().numpy() | |||||
| def _dev_tensor(self): | |||||
| if not skip_tracing: | |||||
| active_trace._require_data(self.__handle) | |||||
| return super()._dev_tensor() | |||||
| def _drop(self): | |||||
| return | |||||
| def _swap_in(self): | |||||
| return | |||||
| def _swap_out(self): | |||||
| return | |||||
| class TracedRawTensor(TraceMixin, RawTensor): | |||||
| pass | |||||
| class TracedLazyTensor(TraceMixin, LazyEvalTensor): | |||||
| pass | |||||
| def assign_raw_tensor(lhs, rhs): | def assign_raw_tensor(lhs, rhs): | ||||
| handle = rhs._handle | |||||
| # Keep isscalar of lhs | |||||
| isscalar = lhs._isscalar | |||||
| rhs.__dict__.clear() | |||||
| lhs.__dict__.clear() | |||||
| lhs.__class__ = RawTensor | |||||
| lhs.__init__(handle, isscalar=isscalar) | |||||
| lhs.__init__(rhs) | |||||
| # this hook turns RawTensor into LazyEvalTensor | |||||
| @apply.register() | |||||
| # this hook turns RawTensor into LazyEvalTensor(varnode) | |||||
| def apply_symbolic_mode(op: OpDef, *args: RawTensor): | def apply_symbolic_mode(op: OpDef, *args: RawTensor): | ||||
| graph = active_trace._lazy_eval_graph | graph = active_trace._lazy_eval_graph | ||||
| ivars = [] | ivars = [] | ||||
| for x in args: | for x in args: | ||||
| var = getattr(x, "_LazyEvalTensor__varnode", None) | |||||
| var = getattr(x, "_varnode", None) | |||||
| if var: | if var: | ||||
| ivars.append(var) | ivars.append(var) | ||||
| else: | else: | ||||
| data_setter = G.InputNode( | data_setter = G.InputNode( | ||||
| device=x.device, | device=x.device, | ||||
| dtype=x.dtype, | dtype=x.dtype, | ||||
| shape=x.shape or (1,), | |||||
| shape=x.numpy().shape or (1,), | |||||
| graph=graph, | graph=graph, | ||||
| use_static_shape=True, | use_static_shape=True, | ||||
| ) | ) | ||||
| @@ -1119,108 +1025,75 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||||
| ivars[0] = opnode.outputs[0] | ivars[0] = opnode.outputs[0] | ||||
| active_trace._lazy_eval_links = (ivars[0],) | active_trace._lazy_eval_links = (ivars[0],) | ||||
| ovars = apply(op, *ivars) | |||||
| ivars = [ | |||||
| RawTensor(ivar._node) if hasattr(ivar, "_node") else RawTensor(ivar) | |||||
| for ivar in ivars | |||||
| ] | |||||
| unset_symbolic() | |||||
| outputs = apply(op, *ivars) | |||||
| set_symbolic() | |||||
| if require_links: | if require_links: | ||||
| active_trace._lazy_eval_links = (ovars[0],) | |||||
| active_trace._lazy_eval_links = (outputs[0]._varnode,) | |||||
| outputs = [LazyEvalTensor(v) for v in ovars] | |||||
| active_trace._lazy_eval_tensors.update(outputs) | |||||
| active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs]) | |||||
| return outputs | return outputs | ||||
| apply.disable(apply_symbolic_mode) | |||||
| @apply.register() | |||||
| def apply_const_symbolic_mode(op: Const, *args: RawTensor): | |||||
| def apply_const_symbolic_mode(value, dtype, device): | |||||
| graph = active_trace._lazy_eval_graph | graph = active_trace._lazy_eval_graph | ||||
| ret = LazyEvalTensor( | |||||
| graph.make_const(op.value, dtype=op.dtype, device=op.device), isscalar=True | |||||
| ) | |||||
| active_trace._lazy_eval_tensors.add(ret) | |||||
| # don't need to unset tracing | |||||
| # because varnode construction will ignore tracing flag | |||||
| ret = RawTensor(graph.make_const(value, dtype=dtype, device=device)) | |||||
| active_trace._lazy_eval_tensors.add(TensorWeakRef(ret)) | |||||
| return (ret,) | return (ret,) | ||||
| apply.disable(apply_const_symbolic_mode) | |||||
| @apply.register() | |||||
| def apply_compiled_mode(op: OpDef, *args: RawTensor): | def apply_compiled_mode(op: OpDef, *args: RawTensor): | ||||
| if skip_tracing: | if skip_tracing: | ||||
| args = [ | args = [ | ||||
| as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||||
| RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||||
| for x in args | for x in args | ||||
| ] | ] | ||||
| return apply.super(op, *args) | |||||
| unset_tracing() | |||||
| ret = apply(op, *args) | |||||
| set_tracing() | |||||
| return ret | |||||
| return active_trace._apply_op(op, args) | return active_trace._apply_op(op, args) | ||||
| apply.disable(apply_compiled_mode) | |||||
| @apply.register() | |||||
| def apply_const_compiled_mode(op: Const, *args: RawTensor): | |||||
| def apply_const_compiled_mode(value, dtype, device, is_const): | |||||
| if skip_tracing: | if skip_tracing: | ||||
| args = [ | args = [ | ||||
| as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||||
| RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||||
| for x in args | for x in args | ||||
| ] | ] | ||||
| return apply.super(op, *args) | |||||
| return active_trace._apply_const(op, args) | |||||
| apply.disable(apply_const_compiled_mode) | |||||
| unset_tracing() | |||||
| ret = RawTensor(value, dtype, device, False) | |||||
| set_tracing() | |||||
| return ret | |||||
| return active_trace._apply_const(value, dtype, device) | |||||
| # this hook injects TraceMixin | # this hook injects TraceMixin | ||||
| @apply.register() | |||||
| def apply_with_tracing(op: OpDef, *args: RawTensor): | def apply_with_tracing(op: OpDef, *args: RawTensor): | ||||
| outputs = apply.super(op, *args) | |||||
| active_trace._record_op(op, args, outputs) | |||||
| return outputs | |||||
| apply.disable(apply_with_tracing) | |||||
| @apply.register() | |||||
| def apply_const_with_tracing(op: Const, *args: RawTensor): | |||||
| outputs = apply.super(op, *args) | |||||
| active_trace._record_const(op, outputs) | |||||
| return outputs | |||||
| apply.disable(apply_const_with_tracing) | |||||
| class BrokenRawTensor(RawTensor): | |||||
| def __getattribute__(self, _): | |||||
| raise RuntimeError("broken due to misuse of tracing") | |||||
| def __setattr__(self, *_): | |||||
| raise RuntimeError("broken due to misuse of tracing") | |||||
| @functools.singledispatch | |||||
| def find_raw_tensor(x): | |||||
| return None | |||||
| @find_raw_tensor.register(RawTensor) | |||||
| def _(x): | |||||
| return x | |||||
| if active_trace._symbolic: | |||||
| outputs = apply_symbolic_mode(op, *args) | |||||
| else: | |||||
| unset_tracing() | |||||
| outputs = apply(op, *args) | |||||
| set_tracing() | |||||
| @find_raw_tensor.register(TensorWrapperBase) | |||||
| def _(x): | |||||
| x = getattr(x, "__wrapped__", None) | |||||
| if x is not None: | |||||
| return find_raw_tensor(x) | |||||
| active_trace._record_op(op, args, outputs) | |||||
| return list(outputs) | |||||
| @find_raw_tensor.register(Tensor) | |||||
| def _(x): | |||||
| x = getattr(x, "_data", None) | |||||
| if x is not None: | |||||
| return find_raw_tensor(x) | |||||
| def apply_const_with_tracing(value, dtype, device, is_const): | |||||
| if active_trace._symbolic: | |||||
| outputs = apply_const_symbolic_mode(value, dtype, device) | |||||
| else: | |||||
| unset_tracing() | |||||
| outputs = (RawTensor(value, dtype, device, False),) | |||||
| set_tracing() | |||||
| active_trace._record_const(outputs) | |||||
| return list(outputs) | |||||
| @@ -28,7 +28,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
| dmap_callback = None | dmap_callback = None | ||||
| q_dict = {"mode": None, "scale": None, "zero_point": None} | q_dict = {"mode": None, "scale": None, "zero_point": None} | ||||
| def __new__(cls, data, dtype=None, device=None): | |||||
| def __new__(cls, data, dtype=None, device=None, is_const=False): | |||||
| if device is None: | if device is None: | ||||
| cn = get_default_device() | cn = get_default_device() | ||||
| elif isinstance(device, str): | elif isinstance(device, str): | ||||
| @@ -40,6 +40,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
| assert isinstance(device, CompNode) | assert isinstance(device, CompNode) | ||||
| cn = device | cn = device | ||||
| # import pdb; pdb.set_trace() | |||||
| if isinstance(data, _Tensor): | if isinstance(data, _Tensor): | ||||
| obj = _Tensor.__new__(cls, data) | obj = _Tensor.__new__(cls, data) | ||||
| else: | else: | ||||
| @@ -47,7 +48,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
| if 0 in data.strides: | if 0 in data.strides: | ||||
| data = data.squeeze().reshape(data.shape) | data = data.squeeze().reshape(data.shape) | ||||
| obj = _Tensor.__new__(cls, data, dtype, cn) | |||||
| obj = _Tensor.__new__(cls, data, dtype, cn, is_const) | |||||
| return obj | return obj | ||||
| @property | @property | ||||
| @@ -296,7 +296,9 @@ void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) | |||||
| Tensor* args[2] = {grad.get(), delta.get()}; | Tensor* args[2] = {grad.get(), delta.get()}; | ||||
| ctx.args = args; | ctx.args = args; | ||||
| ctx.flags = grad->m_flags | delta->m_flags; | ctx.flags = grad->m_flags | delta->m_flags; | ||||
| if (is_tracing) { | |||||
| ctx.flags |= Tensor::Flags::TRACE; | |||||
| } | |||||
| grad = apply(ctx)[0]; | grad = apply(ctx)[0]; | ||||
| } | } | ||||
| @@ -354,6 +356,9 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr | |||||
| } | } | ||||
| ctx.args = args; | ctx.args = args; | ||||
| if (is_tracing) | |||||
| ctx.flags |= Tensor::Flags::TRACE; | |||||
| auto grads = apply(ctx); | auto grads = apply(ctx); | ||||
| size_t j = 0; | size_t j = 0; | ||||
| @@ -11,8 +11,10 @@ | |||||
| #include "./tensor.h" | #include "./tensor.h" | ||||
| #include "./grad.h" | #include "./grad.h" | ||||
| #include "./trace.h" | |||||
| #include "./common.h" | #include "./common.h" | ||||
| #include "./numpy_dtypes.h" | #include "./numpy_dtypes.h" | ||||
| #include "./graph_rt.h" | |||||
| #include <pybind11/numpy.h> | #include <pybind11/numpy.h> | ||||
| #include <pybind11/operators.h> | #include <pybind11/operators.h> | ||||
| @@ -23,6 +25,47 @@ namespace mgb::imperative::python { | |||||
| std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | ||||
| py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing, | |||||
| cpp_apply_compiled_mode, cpp_apply_const_compiled_mode; | |||||
| py::object cpp_apply_backward_varnode; | |||||
| #define REGISTE_APPLY_FUNC(mode) \ | |||||
| void set_##mode(py::object pyf) { \ | |||||
| mode = pybind11::reinterpret_steal<py::object>(pyf); \ | |||||
| } | |||||
| REGISTE_APPLY_FUNC(cpp_apply_with_tracing) | |||||
| REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing) | |||||
| REGISTE_APPLY_FUNC(cpp_apply_compiled_mode) | |||||
| REGISTE_APPLY_FUNC(cpp_apply_const_compiled_mode) | |||||
| REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) | |||||
| #undef REGISTE_APPLY_FUNC | |||||
| bool is_tracing = false; | |||||
| bool is_symbolic = false; | |||||
| bool is_compiled = false; | |||||
| int64_t call_level = 0; | |||||
| #define SET_UNSET_PROP(mode) \ | |||||
| void set_##mode() { \ | |||||
| is_##mode = true; \ | |||||
| } \ | |||||
| void unset_##mode() { \ | |||||
| is_##mode = false; \ | |||||
| } \ | |||||
| SET_UNSET_PROP(tracing) | |||||
| SET_UNSET_PROP(symbolic) | |||||
| SET_UNSET_PROP(compiled) | |||||
| #undef SET_UNSET_PROP | |||||
| bool skip_tracing = false; | |||||
| apply_result_t apply(ApplyContext& ctx) { | apply_result_t apply(ApplyContext& ctx) { | ||||
| // emulating scalar should be put to specific op's apply, e.g., | // emulating scalar should be put to specific op's apply, e.g., | ||||
| // elementwise, reduce, typecvt. Currently it's still handled at python | // elementwise, reduce, typecvt. Currently it's still handled at python | ||||
| @@ -36,7 +79,7 @@ apply_result_t apply(ApplyContext& ctx) { | |||||
| } | } | ||||
| if (ctx.flags & Tensor::Flags::TRACE) { | if (ctx.flags & Tensor::Flags::TRACE) { | ||||
| // TODO: trace | |||||
| return apply_trace(ctx); | |||||
| } else { | } else { | ||||
| SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); | SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); | ||||
| for (size_t i = 0; i < ctx.nargs; ++i) { | for (size_t i = 0; i < ctx.nargs; ++i) { | ||||
| @@ -58,7 +101,6 @@ apply_result_t apply(ApplyContext& ctx) { | |||||
| PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) { | PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) { | ||||
| try { | try { | ||||
| // if (kwnames && PyTuple_GET_SIZE(kwnames)) { | // if (kwnames && PyTuple_GET_SIZE(kwnames)) { | ||||
| // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed"); | // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed"); | ||||
| // return nullptr; | // return nullptr; | ||||
| @@ -67,6 +109,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
| PyErr_SetString(PyExc_TypeError, "expect Op"); | PyErr_SetString(PyExc_TypeError, "expect Op"); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto* op = args[0]; | auto* op = args[0]; | ||||
| PyTypeObject* pytype = args[1]->ob_type; | PyTypeObject* pytype = args[1]->ob_type; | ||||
| @@ -79,18 +122,23 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
| SmallVector<Tensor*, 64> tensors(nargs); | SmallVector<Tensor*, 64> tensors(nargs); | ||||
| ctx.args = &tensors[0]; | ctx.args = &tensors[0]; | ||||
| ctx.nargs = nargs; | ctx.nargs = nargs; | ||||
| if (strstr(op->ob_type->tp_name, "BackwardGraph")) { | |||||
| ctx.backward = true; | |||||
| } | |||||
| for (size_t i = 0; i < nargs; ++i) { | for (size_t i = 0; i < nargs; ++i) { | ||||
| TensorWrapper* tw = TensorWrapper::cast_safe(args[i]); | |||||
| if (!tw) { | |||||
| if (TensorWrapper* tw = TensorWrapper::cast_safe(args[i])) { | |||||
| auto* t = tensors[i] = tw->m_tensor.get(); | |||||
| ctx.flags |= t->m_flags; | |||||
| } else { | |||||
| PyErr_SetString(PyExc_TypeError, "expect Tensor"); | PyErr_SetString(PyExc_TypeError, "expect Tensor"); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto* t = tensors[i] = tw->m_tensor.get(); | |||||
| ctx.flags |= t->m_flags; | |||||
| } | } | ||||
| // TODO: set TRACE flag | |||||
| if (is_tracing) { | |||||
| ctx.flags |= Tensor::Flags::TRACE; | |||||
| } | |||||
| auto outputs = apply(ctx); | auto outputs = apply(ctx); | ||||
| size_t nout = outputs.size(); | size_t nout = outputs.size(); | ||||
| @@ -99,7 +147,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
| ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); | ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); | ||||
| } | } | ||||
| return ret.release().ptr(); | return ret.release().ptr(); | ||||
| } catch (std::exception& e) { | } catch (std::exception& e) { | ||||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | PyErr_SetString(PyExc_RuntimeError, e.what()); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -122,36 +169,116 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
| } | } | ||||
| m_tensor = t->m_tensor; | m_tensor = t->m_tensor; | ||||
| } else { | } else { | ||||
| if (nargs != 3) { | |||||
| throw py::type_error("expect 3 arguments"); | |||||
| } | |||||
| py::detail::loader_life_support life_sup; // required to cast DType | |||||
| auto data = tup[0].cast<py::array>(); | |||||
| DType dtype = tup[1].cast<DType>(); | |||||
| CompNode cn = tup[2].cast<CompNode>(); | |||||
| interpreter::Interpreter::Handle handle; | |||||
| constexpr auto size_threshhold = TensorShape::MAX_NDIM; | |||||
| if (data.size() > size_threshhold) { | |||||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); | |||||
| if (nargs == 1) { | |||||
| auto arg0 = PyTuple_GetItem(args, 0); | |||||
| // for lazy_eval_tensor | |||||
| if (strstr(arg0->ob_type->tp_name, "VarNode")) { | |||||
| if (PyObject_HasAttrString(arg0, "_node")) { | |||||
| arg0 = PyObject_GetAttrString(arg0, "_node"); | |||||
| } | |||||
| m_tensor = std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode *>()); | |||||
| } else { | |||||
| // for DeviceTensorND | |||||
| if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) { | |||||
| auto dv = py::handle(arg0).cast<DeviceTensorND>(); | |||||
| interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv); | |||||
| m_tensor = std::make_shared<Tensor>(handle); | |||||
| } else { | |||||
| throw py::type_error("single argument is not tensor, varnode or devicetensor"); | |||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| HostTensorND ret(cn); | |||||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); | |||||
| } | |||||
| py::detail::loader_life_support life_sup; // required to cast DType | |||||
| auto data = tup[0].cast<py::array>(); | |||||
| DType dtype = tup[1].cast<DType>(); | |||||
| CompNode cn = tup[2].cast<CompNode>(); | |||||
| bool is_const = tup[3].cast<bool>(); | |||||
| if (nargs != 4) { | |||||
| throw py::type_error("expect 3 arguments"); | |||||
| } | |||||
| // const op | |||||
| if (is_const && is_tracing) { | |||||
| py::object pyf; | |||||
| if (is_compiled) { | |||||
| pyf = cpp_apply_const_compiled_mode; | |||||
| } else { | |||||
| pyf = cpp_apply_const_with_tracing; | |||||
| } | |||||
| auto ret = pyf(*tup); | |||||
| auto py_ret = py::reinterpret_borrow<py::list>(ret); | |||||
| if (auto* t = cast_safe(py_ret[0].ptr())) { | |||||
| m_tensor = t->m_tensor; | |||||
| } | |||||
| return; | |||||
| } | |||||
| interpreter::Interpreter::Handle handle; | |||||
| constexpr auto size_threshhold = TensorShape::MAX_NDIM; | |||||
| if (data.size() > size_threshhold) { | |||||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); | |||||
| } else { | |||||
| HostTensorND ret(cn); | |||||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); | |||||
| } | |||||
| m_tensor = std::make_shared<Tensor>(handle); | |||||
| m_tensor = std::make_shared<Tensor>(handle); | |||||
| if (data.ndim() == 0) { | |||||
| m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||||
| if (data.ndim() == 0) { | |||||
| m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| #define REGISTE_TENSORWRAPPER_FUNC(type, member) \ | |||||
| PyObject* TensorWrapper::member() { \ | |||||
| return py::cast(m_tensor->m_trace_info.member).release().ptr(); \ | |||||
| } \ | |||||
| void TensorWrapper::set_##member(PyObject* dest) { \ | |||||
| auto py_dest = py::reinterpret_borrow<py::object>(dest); \ | |||||
| type real_dest = py_dest.cast<type>(); \ | |||||
| m_tensor->m_trace_info.member = real_dest; \ | |||||
| } | |||||
| REGISTE_TENSORWRAPPER_FUNC(bool, data_read) | |||||
| REGISTE_TENSORWRAPPER_FUNC(bool, value_read) | |||||
| REGISTE_TENSORWRAPPER_FUNC(bool, shape_read) | |||||
| REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle) | |||||
| #undef REGISTE_TENSORWRAPPER_FUNC | |||||
| PyObject* TensorWrapper::handle() { | |||||
| return py::cast(m_tensor->m_handle).release().ptr(); | |||||
| } | |||||
| void TensorWrapper::set_handle(PyObject* dest) { | |||||
| auto py_dest = py::reinterpret_borrow<py::object>(dest); | |||||
| SharedHandle real_dest = py_dest.cast<SharedHandle>(); | |||||
| auto&& t = std::move(m_tensor->m_handle); | |||||
| m_tensor->m_handle = std::move(real_dest); | |||||
| } | |||||
| PyObject* TensorWrapper::shape() { | PyObject* TensorWrapper::shape() { | ||||
| if (!skip_tracing) { | |||||
| set_shape_read(py::cast(true). release().ptr()); | |||||
| } | |||||
| if (m_tensor->m_flags & Tensor::Flags::SCALAR) { | if (m_tensor->m_flags & Tensor::Flags::SCALAR) { | ||||
| return PyTuple_New(0); | return PyTuple_New(0); | ||||
| } | } | ||||
| auto&& shape = m_tensor->shape(); | |||||
| TensorShape shape; | |||||
| if (m_tensor->m_var) { | |||||
| shape = m_tensor->m_var->shape(); | |||||
| } else { | |||||
| shape = m_tensor->shape(); | |||||
| } | |||||
| if (!shape.ndim) { | if (!shape.ndim) { | ||||
| Py_RETURN_NONE; | Py_RETURN_NONE; | ||||
| } | } | ||||
| @@ -164,16 +291,38 @@ PyObject* TensorWrapper::shape() { | |||||
| PyObject* TensorWrapper::dtype() { | PyObject* TensorWrapper::dtype() { | ||||
| if (m_tensor->m_var) { | |||||
| return py::cast(m_tensor->m_var->dtype()).release().ptr(); | |||||
| } | |||||
| return py::cast(m_tensor->dtype()).release().ptr(); | return py::cast(m_tensor->dtype()).release().ptr(); | ||||
| } | } | ||||
| PyObject* TensorWrapper::device() { | PyObject* TensorWrapper::device() { | ||||
| if (m_tensor->m_var) { | |||||
| return py::cast(m_tensor->m_var->comp_node()).release().ptr(); | |||||
| } | |||||
| return py::cast(m_tensor->comp_node()).release().ptr(); | return py::cast(m_tensor->comp_node()).release().ptr(); | ||||
| } | } | ||||
| PyObject* TensorWrapper::numpy() { | PyObject* TensorWrapper::numpy() { | ||||
| if (!skip_tracing) { | |||||
| set_value_read(py::cast(true).release().ptr()); | |||||
| } | |||||
| if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) { | |||||
| auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); | |||||
| auto&& type = mgr.get_infer_type(m_tensor->m_var); | |||||
| using InferType = cg::static_infer::InferType; | |||||
| if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { | |||||
| return nullptr; | |||||
| } | |||||
| auto* val = mgr.infer_value_fallible(m_tensor->m_var); | |||||
| if (!val) { | |||||
| return nullptr; | |||||
| } | |||||
| return py::cast(*val).attr("numpy")().release().ptr(); | |||||
| } | |||||
| auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); | auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); | ||||
| auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); | auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); | ||||
| if (!arr) return nullptr; | if (!arr) return nullptr; | ||||
| @@ -184,6 +333,13 @@ PyObject* TensorWrapper::numpy() { | |||||
| return arr.release().ptr(); | return arr.release().ptr(); | ||||
| } | } | ||||
| PyObject* TensorWrapper::varnode() { | |||||
| if (m_tensor->m_var) { | |||||
| return py::cast(m_tensor->m_var).release().ptr(); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void TensorWrapper::reset(PyObject* tensor) { | void TensorWrapper::reset(PyObject* tensor) { | ||||
| TensorWrapper* t = TensorWrapper::cast_safe(tensor); | TensorWrapper* t = TensorWrapper::cast_safe(tensor); | ||||
| if (!t) { | if (!t) { | ||||
| @@ -195,13 +351,22 @@ void TensorWrapper::reset(PyObject* tensor) { | |||||
| PyObject* TensorWrapper::detach() { | PyObject* TensorWrapper::detach() { | ||||
| PyObject* self = wrap_t::pycast(this); | PyObject* self = wrap_t::pycast(this); | ||||
| PyTypeObject* pytype = self->ob_type; | PyTypeObject* pytype = self->ob_type; | ||||
| auto new_tensor = std::make_shared<Tensor>(m_tensor->m_handle); | |||||
| std::shared_ptr<Tensor> new_tensor; | |||||
| if (m_tensor->m_handle.get()) { | |||||
| new_tensor = std::make_shared<Tensor>(m_tensor->m_handle); | |||||
| } else { | |||||
| new_tensor = std::make_shared<Tensor>(m_tensor->m_var); | |||||
| } | |||||
| auto ret = TensorWrapper::make(pytype, std::move(new_tensor)); | auto ret = TensorWrapper::make(pytype, std::move(new_tensor)); | ||||
| return ret.release().ptr(); | return ret.release().ptr(); | ||||
| } | } | ||||
| PyObject* TensorWrapper::_dev_tensor(){ | PyObject* TensorWrapper::_dev_tensor(){ | ||||
| if (!skip_tracing) { | |||||
| set_data_read(py::cast(true).release().ptr()); | |||||
| } | |||||
| auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get()); | auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get()); | ||||
| return py::cast(dev_tensor).release().ptr(); | return py::cast(dev_tensor).release().ptr(); | ||||
| } | } | ||||
| @@ -227,11 +392,14 @@ PyObject* TensorWrapper::isscalar() { | |||||
| } | } | ||||
| } | } | ||||
| void TensorWrapper::setscalar() { | void TensorWrapper::setscalar() { | ||||
| m_tensor->m_flags |= Tensor::Flags::SCALAR; | m_tensor->m_flags |= Tensor::Flags::SCALAR; | ||||
| } | } | ||||
| PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; | |||||
| struct TensorWeakRef { | struct TensorWeakRef { | ||||
| std::weak_ptr<Tensor> wptr; | std::weak_ptr<Tensor> wptr; | ||||
| @@ -262,6 +430,12 @@ void init_tensor(py::module m) { | |||||
| .def<&TensorWrapper::_swap_out>("_swap_out") | .def<&TensorWrapper::_swap_out>("_swap_out") | ||||
| .def<&TensorWrapper::_swap_in>("_swap_in") | .def<&TensorWrapper::_swap_in>("_swap_in") | ||||
| .def<&TensorWrapper::_drop>("_drop") | .def<&TensorWrapper::_drop>("_drop") | ||||
| .def_getset<&TensorWrapper::varnode>("_varnode") | |||||
| .def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read") | |||||
| .def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read") | |||||
| .def_getset<&TensorWrapper::shape_read, &TensorWrapper::set_shape_read>("shape_read") | |||||
| .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle") | |||||
| .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") | |||||
| .finalize(); | .finalize(); | ||||
| if (!tensor_type) throw py::error_already_set(); | if (!tensor_type) throw py::error_already_set(); | ||||
| py::setattr(m, "Tensor", tensor_type); | py::setattr(m, "Tensor", tensor_type); | ||||
| @@ -296,6 +470,25 @@ void init_tensor(py::module m) { | |||||
| if (!grad_key_type) throw py::error_already_set(); | if (!grad_key_type) throw py::error_already_set(); | ||||
| py::setattr(m, "GradKey", grad_key_type); | py::setattr(m, "GradKey", grad_key_type); | ||||
| py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward)); | py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward)); | ||||
| m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing); | |||||
| m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing); | |||||
| m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode); | |||||
| m.def("set_cpp_apply_const_compiled_mode", &set_cpp_apply_const_compiled_mode); | |||||
| m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode); | |||||
| m.attr("skip_tracing") = &skip_tracing; | |||||
| m.attr("call_level") = &call_level; | |||||
| py::class_<SharedHandle>(m, "SharedHandle") | |||||
| .def(py::init<const SharedHandle&>()); | |||||
| m.def("set_tracing", &set_tracing); | |||||
| m.def("unset_tracing", &unset_tracing); | |||||
| m.def("set_symbolic", &set_symbolic); | |||||
| m.def("unset_symbolic", &unset_symbolic); | |||||
| m.def("set_compiled", &set_compiled); | |||||
| m.def("unset_compiled", &unset_compiled); | |||||
| } | } | ||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
| @@ -30,13 +30,10 @@ struct ObjectPtr : B { | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
| #include "./grad_info.h" // for struct GradInfo | #include "./grad_info.h" // for struct GradInfo | ||||
| #include "./trace_info.h" // for struct TraceInfo | |||||
| namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
| struct TraceInfo { | |||||
| }; | |||||
| extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | ||||
| class SharedHandle { | class SharedHandle { | ||||
| @@ -46,7 +43,9 @@ class SharedHandle { | |||||
| public: | public: | ||||
| inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){ | inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){ | ||||
| interpreter_for_py->del(h); | |||||
| if (h) { | |||||
| interpreter_for_py->del(h); | |||||
| } | |||||
| }) {} | }) {} | ||||
| SharedHandle(const SharedHandle&) = default; | SharedHandle(const SharedHandle&) = default; | ||||
| SharedHandle& operator=(const SharedHandle&) = default; | SharedHandle& operator=(const SharedHandle&) = default; | ||||
| @@ -71,11 +70,14 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||||
| GradInfo m_grad_info; | GradInfo m_grad_info; | ||||
| TraceInfo m_trace_info; | TraceInfo m_trace_info; | ||||
| SharedHandle m_handle; | SharedHandle m_handle; | ||||
| cg::VarNode* m_var; | |||||
| using Handle = interpreter::Interpreter::Handle; | using Handle = interpreter::Interpreter::Handle; | ||||
| inline explicit Tensor(Handle handle) : m_handle(handle) {} | |||||
| inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)) {} | |||||
| inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {} | |||||
| inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {} | |||||
| inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {} | |||||
| ~Tensor() = default; | ~Tensor() = default; | ||||
| inline std::shared_ptr<Tensor> copy() { | inline std::shared_ptr<Tensor> copy() { | ||||
| @@ -83,12 +85,28 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||||
| ret->m_flags = m_flags; | ret->m_flags = m_flags; | ||||
| ret->m_grad_info = m_grad_info; | ret->m_grad_info = m_grad_info; | ||||
| ret->m_trace_info = m_trace_info; | ret->m_trace_info = m_trace_info; | ||||
| ret->m_var = m_var; | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| inline DType dtype() {return interpreter_for_py->get_dtype(m_handle.get());} | |||||
| inline CompNode comp_node() {return interpreter_for_py->get_device(m_handle.get());} | |||||
| inline TensorShape shape() {return interpreter_for_py->get_shape(m_handle.get());} | |||||
| inline DType dtype() { | |||||
| if (m_var) { | |||||
| return m_var->dtype(); | |||||
| } | |||||
| return interpreter_for_py->get_dtype(m_handle.get()); | |||||
| } | |||||
| inline CompNode comp_node() { | |||||
| if (m_var) { | |||||
| return m_var->comp_node(); | |||||
| } | |||||
| return interpreter_for_py->get_device(m_handle.get()); | |||||
| } | |||||
| inline TensorShape shape() { | |||||
| if (m_var) { | |||||
| return m_var->shape(); | |||||
| } | |||||
| return interpreter_for_py->get_shape(m_handle.get()); | |||||
| } | |||||
| }; | }; | ||||
| @@ -135,6 +153,19 @@ struct TensorWrapper { | |||||
| void _swap_in(); | void _swap_in(); | ||||
| void _swap_out(); | void _swap_out(); | ||||
| void _drop(); | void _drop(); | ||||
| PyObject* varnode(); | |||||
| PyObject* handle(); | |||||
| void set_handle(PyObject *); | |||||
| PyObject* data_read(); | |||||
| PyObject* value_read(); | |||||
| PyObject* shape_read(); | |||||
| PyObject* mixin_handle(); | |||||
| void set_data_read(PyObject*); | |||||
| void set_value_read(PyObject*); | |||||
| void set_shape_read(PyObject*); | |||||
| void set_mixin_handle(PyObject*); | |||||
| }; | }; | ||||
| @@ -145,6 +176,7 @@ struct ApplyContext { | |||||
| std::shared_ptr<OpDef> op; | std::shared_ptr<OpDef> op; | ||||
| Tensor*const* args; | Tensor*const* args; | ||||
| size_t nargs; | size_t nargs; | ||||
| bool backward = false; | |||||
| }; | }; | ||||
| using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; | using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; | ||||
| @@ -153,6 +185,14 @@ apply_result_t apply(ApplyContext& ctx); | |||||
| void init_tensor(pybind11::module); | void init_tensor(pybind11::module); | ||||
| extern bool is_tracing; | |||||
| extern bool is_symbolic; | |||||
| extern bool is_compiled; | |||||
| extern int64_t call_level; | |||||
| extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode; | |||||
| extern pybind11::object cpp_apply_backward_varnode; | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
| namespace pybind11::detail { | namespace pybind11::detail { | ||||
| @@ -0,0 +1,94 @@ | |||||
| /** | |||||
| * \file imperative/python/src/trace.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "./trace.h" | |||||
| #include "./helper.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| namespace py = pybind11; | |||||
| namespace mgb::imperative::python { | |||||
| apply_result_t apply_tensor_on_var_node(ApplyContext& ctx) { | |||||
| apply_result_t outputs; | |||||
| cg::VarNodeArray vinputs(ctx.nargs); | |||||
| for (size_t i = 0; i < ctx.nargs; i++) { | |||||
| vinputs[i] = ctx.args[i]->m_var; | |||||
| } | |||||
| auto ovars = OpDef::apply_on_var_node(*ctx.op, vinputs); | |||||
| for (size_t i = 0; i < ovars.size(); i++) { | |||||
| outputs.emplace_back(std::make_shared<Tensor>(ovars[i])); | |||||
| } | |||||
| return outputs; | |||||
| } | |||||
| apply_result_t apply_trace(ApplyContext& ctx) { | |||||
| apply_result_t outputs; | |||||
| bool run_apply_on_var_node = false; | |||||
| for (size_t i = 0; i < ctx.nargs; i++) { | |||||
| run_apply_on_var_node |= ((ctx.args[i]->m_handle.get() == nullptr) & (ctx.args[i]->m_var != nullptr)); | |||||
| } | |||||
| if (ctx.backward) { | |||||
| // reach here when symbolic=True or compiled=True | |||||
| // call megbrain_graph.py apply(BackwardGraph, *args) | |||||
| auto args = py::tuple(ctx.nargs); | |||||
| for (size_t i = 0; i < ctx.nargs; i++) { | |||||
| args[i] = py::cast(ctx.args[i]->m_var); | |||||
| } | |||||
| py::object ret = cpp_apply_backward_varnode(py::cast(ctx.op), *args); | |||||
| if (!ret) { | |||||
| throw py::value_error("invalid py object call"); | |||||
| } | |||||
| // assumption: python function always returns PyList | |||||
| auto tup = py::reinterpret_borrow<py::list>(ret); | |||||
| for (auto i = 0; i < tup.size(); i++) { | |||||
| auto pitem = tup[i].cast<cg::VarNode *>(); | |||||
| outputs.emplace_back(std::make_shared<Tensor>(pitem)); | |||||
| } | |||||
| return outputs; | |||||
| } | |||||
| if (run_apply_on_var_node && !is_symbolic) { | |||||
| return apply_tensor_on_var_node(ctx); | |||||
| } | |||||
| py::object pyf; | |||||
| if (is_compiled) { | |||||
| // run apply in compiled mode, step 2, 3, etc | |||||
| pyf = cpp_apply_compiled_mode; | |||||
| } else { | |||||
| // run first step, both symbolic and non symbolic | |||||
| pyf = cpp_apply_with_tracing; | |||||
| } | |||||
| auto args = py::tuple(ctx.nargs); | |||||
| for (size_t i = 0; i < ctx.nargs; i++) { | |||||
| args[i] = TensorWrapper::make(std::move(std::shared_ptr<Tensor>(ctx.args[i]))).release(); | |||||
| } | |||||
| auto ret = pyf(py::cast(ctx.op), *args); | |||||
| // assumption: python function always returns PyList | |||||
| auto tup = py::reinterpret_borrow<py::list>(ret); | |||||
| for (auto i = 0; i < tup.size(); i++) { | |||||
| auto tw = TensorWrapper::cast_safe(tup[i].ptr()); | |||||
| outputs.emplace_back(tw->m_tensor); | |||||
| } | |||||
| return outputs; | |||||
| } | |||||
| } // namespace mgb::imperative::python | |||||
| @@ -9,9 +9,10 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include "./tensor.h" | |||||
| namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
| struct TraceInfo { | |||||
| }; | |||||
| apply_result_t apply_trace(ApplyContext& ctx); | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
| @@ -0,0 +1,24 @@ | |||||
| /** | |||||
| * \file imperative/python/src/trace_info.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "inttypes.h" | |||||
| namespace mgb::imperative::python { | |||||
| struct TraceInfo { | |||||
| int64_t mixin_handle = -1; | |||||
| bool data_read = false; | |||||
| bool value_read = false; | |||||
| bool shape_read = false; | |||||
| }; | |||||
| } // namespace mgb::imperative::python | |||||
| @@ -19,8 +19,6 @@ from megengine import tensor | |||||
| from megengine.core._trace_option import set_symbolic_shape | from megengine.core._trace_option import set_symbolic_shape | ||||
| from megengine.core.ops import builtin as ops | from megengine.core.ops import builtin as ops | ||||
| from megengine.core.ops.builtin import Elemwise | from megengine.core.ops.builtin import Elemwise | ||||
| from megengine.core.tensor.core import apply | |||||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | |||||
| from megengine.core.tensor.utils import isscalar | from megengine.core.tensor.utils import isscalar | ||||
| from megengine.functional import exp, log | from megengine.functional import exp, log | ||||
| from megengine.jit import exclude_from_trace, trace | from megengine.jit import exclude_from_trace, trace | ||||
| @@ -32,35 +30,32 @@ def test_trace(): | |||||
| @trace(symbolic=symbolic) | @trace(symbolic=symbolic) | ||||
| def f(x): | def f(x): | ||||
| op = ops.Elemwise(Elemwise.Mode.NEGATE) | |||||
| (y,) = apply(op, x) | |||||
| return y | |||||
| return -x | |||||
| x = as_raw_tensor([1]).numpy() | |||||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||||
| x = tensor([1]) | |||||
| y = f(x).numpy() | |||||
| for i in range(3): | for i in range(3): | ||||
| np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||||
| np.testing.assert_equal(f(x).numpy(), y) | |||||
| def test_exclude_from_trace(): | def test_exclude_from_trace(): | ||||
| for symbolic in [False, True]: | |||||
| for symbolic in [False]: | |||||
| @trace(symbolic=symbolic) | @trace(symbolic=symbolic) | ||||
| def f(x): | def f(x): | ||||
| neg = ops.Elemwise(Elemwise.Mode.NEGATE) | |||||
| (x,) = apply(neg, x) | |||||
| x = -x | |||||
| with exclude_from_trace(): | with exclude_from_trace(): | ||||
| if i % 2: | if i % 2: | ||||
| (x,) = apply(neg, x) | |||||
| (x,) = apply(neg, x) | |||||
| x = -x | |||||
| x = -x | |||||
| return x | return x | ||||
| x = as_raw_tensor([1]).numpy() | |||||
| x = tensor([1]) | |||||
| for i in range(3): | for i in range(3): | ||||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||||
| np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||||
| y = f(x).numpy() | |||||
| np.testing.assert_equal(f(x).numpy(), y) | |||||
| def test_print_in_trace(): | def test_print_in_trace(): | ||||
| @@ -69,36 +64,33 @@ def test_print_in_trace(): | |||||
| @trace(symbolic=symbolic) | @trace(symbolic=symbolic) | ||||
| def f(x): | def f(x): | ||||
| nonlocal buf | nonlocal buf | ||||
| neg = ops.Elemwise(Elemwise.Mode.NEGATE) | |||||
| (x,) = apply(neg, x) | |||||
| x = -x | |||||
| buf = x.numpy() | buf = x.numpy() | ||||
| (x,) = apply(neg, x) | |||||
| x = -x | |||||
| return x | return x | ||||
| buf = None | buf = None | ||||
| x = as_raw_tensor([1]).numpy() | |||||
| x = tensor([1]) | |||||
| for i in range(3): | for i in range(3): | ||||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||||
| y = f(x).numpy() | |||||
| z = buf | z = buf | ||||
| buf = None | buf = None | ||||
| np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||||
| np.testing.assert_equal(f(x).numpy(), y) | |||||
| np.testing.assert_equal(z, buf) | np.testing.assert_equal(z, buf) | ||||
| def test_dump(): | def test_dump(): | ||||
| @trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
| def f(a, b): | def f(a, b): | ||||
| op = ops.Elemwise(Elemwise.Mode.ADD) | |||||
| (y,) = apply(op, a, b) | |||||
| return y | |||||
| return a + b | |||||
| a = as_raw_tensor([2]).numpy() | |||||
| b = as_raw_tensor([4]).numpy() | |||||
| y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy() | |||||
| a = tensor([2]) | |||||
| b = tensor([4]) | |||||
| y = f(a, b).numpy() | |||||
| for i in range(3): | for i in range(3): | ||||
| np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y) | |||||
| np.testing.assert_equal(f(a, b).numpy(), y) | |||||
| file = io.BytesIO() | file = io.BytesIO() | ||||
| dump_info = f.dump(file) | dump_info = f.dump(file) | ||||
| @@ -111,19 +103,17 @@ def test_dump(): | |||||
| def test_capture_dump(): | def test_capture_dump(): | ||||
| a = as_raw_tensor([2]) | |||||
| a = tensor([2]) | |||||
| @trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
| def f(x): | def f(x): | ||||
| op = ops.Elemwise(Elemwise.Mode.MUL) | |||||
| (y,) = apply(op, x, a) | |||||
| return y | |||||
| return x * a | |||||
| x = as_raw_tensor([3]).numpy() | |||||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||||
| x = tensor([3]) | |||||
| y = f(x).numpy() | |||||
| for i in range(3): | for i in range(3): | ||||
| np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||||
| np.testing.assert_equal(f(x).numpy(), y) | |||||
| file = io.BytesIO() | file = io.BytesIO() | ||||
| f.dump(file) | f.dump(file) | ||||
| @@ -133,19 +123,17 @@ def test_capture_dump(): | |||||
| def test_dump_volatile(): | def test_dump_volatile(): | ||||
| p = as_raw_tensor([2]) | |||||
| p = tensor([2]) | |||||
| @trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
| def f(x): | def f(x): | ||||
| op = ops.Elemwise(Elemwise.Mode.MUL) | |||||
| (y,) = apply(op, x, p) | |||||
| return y | |||||
| return x * p | |||||
| x = as_raw_tensor([3]).numpy() | |||||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||||
| x = tensor([3]) | |||||
| y = f(x).numpy() | |||||
| for i in range(3): | for i in range(3): | ||||
| np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||||
| np.testing.assert_equal(f(x).numpy(), y) | |||||
| file = io.BytesIO() | file = io.BytesIO() | ||||
| f.dump(file, optimize_for_inference=False) | f.dump(file, optimize_for_inference=False) | ||||
| @@ -163,21 +151,18 @@ def test_trace_profiler(): | |||||
| @trace(symbolic=symbolic, profiling=True) | @trace(symbolic=symbolic, profiling=True) | ||||
| def f(x): | def f(x): | ||||
| op = ops.Elemwise(Elemwise.Mode.NEGATE) | |||||
| (y,) = apply(op, x) | |||||
| return y | |||||
| return -x | |||||
| x = as_raw_tensor([1]).numpy() | |||||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||||
| x = tensor([1]) | |||||
| y = f(x).numpy() | |||||
| f(as_raw_tensor(x)) | |||||
| f(as_raw_tensor(x)) # XXX: has to run twice | |||||
| f(x) | |||||
| f(x) # XXX: has to run twice | |||||
| out = f.get_profile() | out = f.get_profile() | ||||
| assert out.get("profiler") | assert out.get("profiler") | ||||
| @pytest.mark.skip(reason="force opt_level=0 when building graph") | |||||
| def test_goptions(): | def test_goptions(): | ||||
| @trace(symbolic=True, opt_level=0, capture_as_const=True) | @trace(symbolic=True, opt_level=0, capture_as_const=True) | ||||
| def f(x): | def f(x): | ||||
| @@ -196,7 +181,6 @@ def test_goptions(): | |||||
| np.testing.assert_equal(g(d).numpy().item(), 1.0) | np.testing.assert_equal(g(d).numpy().item(), 1.0) | ||||
| @pytest.mark.skip(reason="force opt_level=0 when building graph") | |||||
| def test_goptions_log_sum_exp(): | def test_goptions_log_sum_exp(): | ||||
| @trace(symbolic=True, opt_level=0, capture_as_const=True) | @trace(symbolic=True, opt_level=0, capture_as_const=True) | ||||
| def f(x, y): | def f(x, y): | ||||
| @@ -256,8 +240,7 @@ def test_optimize_for_inference_broadcast(): | |||||
| @trace(capture_as_const=True, symbolic_shape=True) | @trace(capture_as_const=True, symbolic_shape=True) | ||||
| def f(): | def f(): | ||||
| (b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) | |||||
| return b | |||||
| return a._broadcast(tensor([1, 10], dtype=np.int32)) | |||||
| f() | f() | ||||
| f.dump(io.BytesIO()) | f.dump(io.BytesIO()) | ||||
| @@ -387,7 +370,9 @@ def test_trace_nms(): | |||||
| @trace(symbolic=False) | @trace(symbolic=False) | ||||
| def f(boxes, scores): | def f(boxes, scores): | ||||
| # with tracing, max_output must be specified | |||||
| results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20) | results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20) | ||||
| # without tracing, max output can be inferred inside nms | |||||
| with exclude_from_trace(): | with exclude_from_trace(): | ||||
| _ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5) | _ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5) | ||||
| return results | return results | ||||
| @@ -318,7 +318,6 @@ def optimize_for_inference(args, outputs): | |||||
| ), "optimize_for_inference should be set when {} is given".format(k) | ), "optimize_for_inference should be set when {} is given".format(k) | ||||
| kwargs[v] = True | kwargs[v] = True | ||||
| outputs = [G.VarNode(output) for output in outputs] | |||||
| if args.optimize_for_inference: | if args.optimize_for_inference: | ||||
| outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)] | outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)] | ||||
| @@ -84,7 +84,7 @@ def main(): | |||||
| minibatch = next(val_dataset) | minibatch = next(val_dataset) | ||||
| net.eval() | net.eval() | ||||
| _, loss = val_fun(data, label) | _, loss = val_fun(data, label) | ||||
| loss = loss.numpy()[0] | |||||
| loss = loss.numpy() | |||||
| val_loss.append((step, loss)) | val_loss.append((step, loss)) | ||||
| print("Step: {} loss={}".format(step, loss)) | print("Step: {} loss={}".format(step, loss)) | ||||
| opt.step() | opt.step() | ||||