| @@ -17,15 +17,31 @@ from ..ops.builtin import OpDef | |||
| from .core import OpBase, TensorBase, apply | |||
| class CompiledFunction: | |||
| def __init__(self, graph, function): | |||
| self._graph = graph | |||
| self._function = function | |||
| class Graph(_imperative_rt.ComputingGraph): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self._var_cache = weakref.WeakKeyDictionary() | |||
| self._op_cache = weakref.WeakKeyDictionary() | |||
| self._executor = ThreadPoolExecutor(1) | |||
| self._function = None | |||
| self._future = None | |||
| def _wrap(self, obj): | |||
| if type(obj) is _imperative_rt.VarNode: | |||
| wrapper, cache = VarNode, self._var_cache | |||
| elif type(obj) is _imperative_rt.OperatorNode: | |||
| wrapper, cache = OpNode, self._op_cache | |||
| if obj not in cache: | |||
| cache[obj] = wrapper(obj) | |||
| return cache[obj] | |||
| def compile(self, *args): | |||
| self._function = super().compile(_unwrap(args)) | |||
| return self | |||
| def execute(self, *args): | |||
| assert self._future is None | |||
| self._future = self._graph._executor.submit(self._function.execute, *args) | |||
| self._future = self._executor.submit(self._function.execute, *args) | |||
| def wait(self): | |||
| assert self._future is not None | |||
| @@ -40,30 +56,23 @@ class CompiledFunction: | |||
| self.execute(*args) | |||
| return self.wait() | |||
| def make_const(self, data, dtype=None, device=None): | |||
| if isinstance(data, _imperative_rt.DeviceTensorND): | |||
| assert dtype is None and device is None | |||
| return self._wrap(_imperative_rt.make_shared(self, data)) | |||
| else: | |||
| device = as_device(device).to_c() | |||
| return self._wrap(_imperative_rt.make_const(self, data, device, dtype)) | |||
| class Graph(_imperative_rt.ComputingGraph): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self._var_cache = weakref.WeakKeyDictionary() | |||
| self._op_cache = weakref.WeakKeyDictionary() | |||
| self._executor = ThreadPoolExecutor(1) | |||
| def _wrap(self, obj): | |||
| if type(obj) is _imperative_rt.VarNode: | |||
| wrapper, cache = VarNode, self._var_cache | |||
| elif type(obj) is _imperative_rt.OperatorNode: | |||
| wrapper, cache = OpNode, self._op_cache | |||
| if obj not in cache: | |||
| cache[obj] = wrapper(obj) | |||
| return cache[obj] | |||
| def compile(self, *args): | |||
| return CompiledFunction(self, super().compile(_unwrap(args))) | |||
| def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None): | |||
| opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self) | |||
| return opnode.outputs[0] | |||
| class VarNode(TensorBase): | |||
| def __init__(self, node: _imperative_rt.VarNode): | |||
| self._node = node | |||
| self.graph._var_cache[node] = self | |||
| @property | |||
| def graph(self) -> Graph: | |||
| @@ -81,10 +90,15 @@ class VarNode(TensorBase): | |||
| def device(self): | |||
| return as_device(self._node.comp_node) | |||
| @property | |||
| def shape(self): | |||
| return self._node.shape | |||
| class OpNode: | |||
| def __init__(self, node: _imperative_rt.OperatorNode): | |||
| self._node = node | |||
| self.graph._op_cache[node] = self | |||
| @property | |||
| def graph(self) -> Graph: | |||
| @@ -117,21 +131,21 @@ def _(op: OpDef, *args: VarNode): | |||
| return _wrap(outputs) | |||
| def input_callback(callback, *args, device=None, dtype=None, graph=None): | |||
| def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | |||
| outputs = _imperative_rt.input_callback( | |||
| callback, as_device(device).to_c(), dtype, _unwrap(args), graph=graph | |||
| callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph | |||
| ) | |||
| value, dummy = _wrap(outputs) | |||
| return value, dummy | |||
| class InputNode(OpNode): | |||
| def __init__(self, *args: VarNode, device=None, dtype=None, graph=None): | |||
| def __init__(self, *args: VarNode, device=None, dtype=None, shape=None, graph=None): | |||
| r = _imperative_rt.DeviceTensorNDRendezvous() | |||
| if device is not None: | |||
| device = as_device(device).to_c() | |||
| outputs = _imperative_rt.input_callback( | |||
| r, device, dtype, _unwrap(args), graph=graph | |||
| r, device, dtype, shape, _unwrap(args), graph=graph | |||
| ) | |||
| super().__init__(outputs[0].owner) | |||
| self._rendezvous = r | |||
| @@ -169,6 +183,29 @@ class OutputNode(OpNode): | |||
| def get_value(self): | |||
| return self._rendezvous.get() | |||
| def drop_value(self): | |||
| self._rendezvous.drop() | |||
| def reset(self): | |||
| self._rendezvous.reset() | |||
| class ValueOutputNode(OpNode): | |||
| def __init__(self, var, *args): | |||
| args = (var,) + args | |||
| r = _imperative_rt.HostTensorNDRendezvous() | |||
| dummy = _imperative_rt.value_output_callback(r, _unwrap(args)) | |||
| super().__init__(dummy.owner) | |||
| self._rendezvous = r | |||
| def get_value(self): | |||
| hostnd, event = self._rendezvous.get() | |||
| event.wait() | |||
| return hostnd.numpy() | |||
| def drop_value(self): | |||
| self._rendezvous.drop() | |||
| def reset(self): | |||
| self._rendezvous.reset() | |||
| @@ -192,5 +229,8 @@ class AttrOutputNode(OpNode): | |||
| attr = self._rendezvous.get() | |||
| return TensorAttr(attr.shape, attr.dtype, as_device(attr.comp_node)) | |||
| def drop_value(self): | |||
| self._rendezvous.drop() | |||
| def reset(self): | |||
| self._rendezvous.reset() | |||
| @@ -31,11 +31,13 @@ class RawTensor(TensorBase): | |||
| _init_cb = None | |||
| _del_cb = None | |||
| _handle = None | |||
| def __init__(self, handle): | |||
| def __init__(self, handle=None): | |||
| self._handle = handle | |||
| if self._init_cb: | |||
| self._init_cb() | |||
| if handle is not None: | |||
| if self._init_cb: | |||
| self._init_cb() | |||
| @property | |||
| def dtype(self): | |||
| @@ -61,9 +63,10 @@ class RawTensor(TensorBase): | |||
| ) | |||
| def __del__(self): | |||
| if self._del_cb: | |||
| self._del_cb() | |||
| delete(self._handle) | |||
| if self._handle is not None: | |||
| if self._del_cb: | |||
| self._del_cb() | |||
| delete(self._handle) | |||
| @apply.register() | |||
| @@ -89,6 +92,11 @@ def as_raw_tensor(obj, dtype=None, device=None): | |||
| return as_raw_tensor(obj, device=device) | |||
| @as_raw_tensor.register(DeviceTensorND) | |||
| def _(data: DeviceTensorND): | |||
| return RawTensor(put(data)) | |||
| @as_raw_tensor.register(np.ndarray) | |||
| def _(array: np.ndarray, dtype=None, device=None): | |||
| device = None if device is None else as_device(device).to_c() | |||
| @@ -0,0 +1 @@ | |||
| from .tracing import exclude_from_trace, trace | |||
| @@ -0,0 +1,514 @@ | |||
| import contextlib | |||
| import functools | |||
| import typing | |||
| import weakref | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import megbrain_graph as G | |||
| from ..core.tensor.core import OpBase, apply | |||
| from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor | |||
| class TraceMismatchError(RuntimeError): | |||
| pass | |||
| active_trace = None | |||
| skip_tracing = False | |||
| @contextlib.contextmanager | |||
| def exclude_from_trace(): | |||
| global skip_tracing | |||
| if skip_tracing: | |||
| yield | |||
| return | |||
| try: | |||
| skip_tracing = True | |||
| if active_trace is not None: | |||
| active_trace._begin_excluded_region() | |||
| yield | |||
| finally: | |||
| skip_tracing = False | |||
| class TensorInfo: | |||
| __slots__ = ( | |||
| # collected attributes | |||
| "external", | |||
| "exported", | |||
| "data_read", | |||
| "shape_read", | |||
| "value_read", | |||
| "device", | |||
| "dtype", | |||
| "bound_data", | |||
| # resources for execution | |||
| "varnode", | |||
| "data_setter", | |||
| "shape_reader", | |||
| "value_reader", | |||
| "data_reader", | |||
| ) | |||
| def __init__(self): | |||
| self.exported = None | |||
| self.data_read = None | |||
| self.shape_read = None | |||
| self.value_read = None | |||
| self.bound_data = None | |||
| self.data_setter = None | |||
| self.shape_reader = None | |||
| self.value_reader = None | |||
| self.data_reader = None | |||
| class trace: | |||
| def __new__(cls, *args, **kwargs): | |||
| if not args: | |||
| return functools.partial(cls, **kwargs) | |||
| self = super().__new__(cls) | |||
| self.__init__(*args, **kwargs) | |||
| return self | |||
| def __init__(self, function, symbolic=False, capture_as_const=False): | |||
| self.__wrapped__ = function | |||
| self._symbolic = symbolic | |||
| self._capture_as_const = capture_as_const | |||
| self._capture_static_shape = False | |||
| self._untraced = True | |||
| self._tinfo = [] # handle -> TensorInfo | |||
| self._seq = [] | |||
| self._pc = 0 | |||
| self._graph = None | |||
| self._need_reset_nodes = None | |||
| self._lazy_eval_graph = None | |||
| self._lazy_eval_tensors = weakref.WeakSet() | |||
| self._active_tensors = weakref.WeakSet() | |||
| def _new_handle(self): | |||
| handle = len(self._tinfo) | |||
| info = TensorInfo() | |||
| self._tinfo.append(info) | |||
| return handle, info | |||
| def _apply_op(self, op, args): | |||
| assert not self._untraced | |||
| # check against trace | |||
| if self._pc >= len(self._seq): | |||
| raise TraceMismatchError("trace should end here, but more op observed") | |||
| record = self._seq[self._pc] | |||
| op_, ihandles, ohandles = record | |||
| if op != op_: | |||
| raise TraceMismatchError("op different from last time") | |||
| if len(ihandles) != len(args): | |||
| raise TraceMismatchError("op input size different from last time") | |||
| for h, x in zip(ihandles, args): | |||
| info = self._tinfo[h] | |||
| if info.external: | |||
| if ( | |||
| x.__class__ is CompiledTensorProxy | |||
| and not self._tinfo[x._CompiledTensorProxy__handle].exported | |||
| ): | |||
| raise TraceMismatchError( | |||
| "failed to capture: input was an external tensor " | |||
| "last time, got an internal tensor this time" | |||
| ) | |||
| if info.bound_data: | |||
| if x.__class__ is CompiledTensorProxy: | |||
| raise TraceMismatchError( | |||
| "const capture violated: was an external tensor " | |||
| "last time, got an internal tensor this time" | |||
| ) | |||
| if x._handle != info.bound_data._handle: | |||
| raise TraceMismatchError( | |||
| "const capture violated: got " | |||
| "a different tensor this time" | |||
| ) | |||
| else: | |||
| if info.dtype != x.dtype: | |||
| raise TraceMismatchError( | |||
| "failed to capture: different dtype from last time" | |||
| ) | |||
| if info.device != x.device: | |||
| raise TraceMismatchError( | |||
| "failed to capture: different device from last time" | |||
| ) | |||
| info.data_setter.set_value(x._dev_tensor()) | |||
| else: | |||
| if x.__class__ is not CompiledTensorProxy: | |||
| raise TraceMismatchError( | |||
| "unexpected capture: trying to use an external tensor as input, " | |||
| "but that input was an internal tensor last time" | |||
| ) | |||
| if x._CompiledTensorProxy__handle != h: | |||
| raise TraceMismatchError( | |||
| "mis-wiring: input edge to an data flow " | |||
| "graph node is different from last time" | |||
| ) | |||
| self._pc += 1 | |||
| outputs = tuple([CompiledTensorProxy(h) for h in ohandles]) | |||
| self._active_tensors.update(outputs) | |||
| return outputs | |||
| def _record_op(self, op, inputs, outputs): | |||
| if skip_tracing: | |||
| for x in inputs: | |||
| h = getattr(x, "_TraceMixin__handle", None) | |||
| if h is not None: | |||
| self._tinfo[h].data_read = True | |||
| return | |||
| ihandles = [] | |||
| for x in inputs: | |||
| h = getattr(x, "_TraceMixin__handle", None) | |||
| if h is None or (not self._capture_as_const and self._tinfo[h].exported): | |||
| h, info = self._new_handle() | |||
| info.external = True | |||
| info.device = x.device | |||
| info.dtype = x.dtype | |||
| if self._capture_as_const: | |||
| info.bound_data = x | |||
| ihandles.append(h) | |||
| ohandles = [] | |||
| for x in outputs: | |||
| h, info = self._new_handle() | |||
| ohandles.append(h) | |||
| info.external = False | |||
| TraceMixin._TraceMixin__inject(x, h) | |||
| self._seq.append((op, tuple(ihandles), tuple(ohandles))) | |||
| self._active_tensors.update(outputs) | |||
| @contextlib.contextmanager | |||
| def _setup(self): | |||
| global active_trace | |||
| if active_trace: | |||
| raise NotImplementedError("sorry, not implemented: nested trace") | |||
| active_trace = self | |||
| if self._untraced: | |||
| apply.enable(apply_with_tracing) | |||
| if self._symbolic: | |||
| apply.enable(apply_symbolic_mode) | |||
| self._lazy_eval_graph = G.Graph() | |||
| else: | |||
| apply.enable(apply_compiled_mode) | |||
| if self._graph is None: | |||
| self._compile() | |||
| self._graph.execute() | |||
| yield | |||
| escaped_tensors = tuple(self._active_tensors) | |||
| self._active_tensors.clear() | |||
| if self._untraced: | |||
| for x in escaped_tensors: | |||
| info = self._tinfo[x._TraceMixin__handle] | |||
| info.data_read = True | |||
| x._TraceMixin__restore() | |||
| if self._symbolic: | |||
| # eval lazy eval tensors | |||
| lazy_eval_tensors = tuple(self._lazy_eval_tensors) | |||
| if lazy_eval_tensors: | |||
| readers = [ | |||
| G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] | |||
| for x in lazy_eval_tensors | |||
| ] | |||
| self._lazy_eval_graph.compile(*readers) | |||
| self._lazy_eval_graph() | |||
| for r, x in zip(readers, lazy_eval_tensors): | |||
| assign_raw_tensor(x, as_raw_tensor(r.op.get_value())) | |||
| self._lazy_eval_graph = None | |||
| self._lazy_eval_tensors = None | |||
| self._untraced = False | |||
| else: | |||
| if self._pc != len(self._seq): | |||
| raise TraceMismatchError("premature end") | |||
| for x in escaped_tensors: | |||
| assign_raw_tensor(x, as_raw_tensor(x._dev_tensor())) | |||
| self._graph.wait() | |||
| self._reset_exec_env() | |||
| self._pc = 0 | |||
| apply.disable(apply_with_tracing) | |||
| apply.disable(apply_symbolic_mode) | |||
| apply.disable(apply_compiled_mode) | |||
| active_trace = None | |||
| def _begin_excluded_region(self): | |||
| if self._untraced: | |||
| # conditionally reading a compiled tensor in excluded region | |||
| # is permitted, so we have to assume every tensor might be read | |||
| for x in self._active_tensors: | |||
| info = self._tinfo[x._TraceMixin__handle] | |||
| info.exported = True | |||
| info.data_read = True | |||
| def _compile(self): | |||
| graph = self._graph = G.Graph() | |||
| # graph.options.graph_opt_level = 0 | |||
| need_reset_nodes = self._need_reset_nodes = [] | |||
| # links enforce ordering of I/O nodes | |||
| links = () | |||
| for op, ihandles, ohandles in self._seq: | |||
| ivars = [] | |||
| readers = [] | |||
| for h in ihandles: | |||
| info = self._tinfo[h] | |||
| if not hasattr(info, "varnode"): | |||
| assert info.external | |||
| if info.bound_data: | |||
| info.varnode = graph.make_const(info.bound_data._dev_tensor()) | |||
| else: | |||
| opnode = info.data_setter = G.InputNode( | |||
| *links, device=info.device, dtype=info.dtype, graph=graph | |||
| ) | |||
| need_reset_nodes.append(opnode) | |||
| info.varnode, *links = opnode.outputs | |||
| ivars.append(info.varnode) | |||
| ovars = apply(op, *ivars) | |||
| assert len(ovars) == len(ohandles) | |||
| for h, v in zip(ohandles, ovars): | |||
| info = self._tinfo[h] | |||
| info.varnode = v | |||
| def add_reader(opnode): | |||
| nonlocal links | |||
| need_reset_nodes.append(opnode) | |||
| readers.append(opnode.outputs[0]) | |||
| links = opnode.outputs | |||
| if info.data_read: | |||
| # Shape can be obtained from data so doesn't need its own | |||
| # output node. On the other hand, value is read separately | |||
| # to leverage eager h2d copy | |||
| info.shape_read = False | |||
| opnode = info.data_reader = G.OutputNode(v, *links) | |||
| add_reader(opnode) | |||
| if info.value_read: | |||
| opnode = info.value_reader = G.ValueOutputNode(v, *links) | |||
| add_reader(opnode) | |||
| if info.shape_read: | |||
| opnode = info.shape_reader = G.AttrOutputNode(v, *links) | |||
| add_reader(opnode) | |||
| graph.compile(*readers) | |||
| def _reset_exec_env(self): | |||
| for opnode in self._need_reset_nodes: | |||
| opnode.reset() | |||
| def _require_shape(self, handle): | |||
| info = self._tinfo[handle] | |||
| info.shape_read = True | |||
| def _require_value(self, handle): | |||
| info = self._tinfo[handle] | |||
| info.value_read = True | |||
| def _require_data(self, handle): | |||
| info = self._tinfo[handle] | |||
| info.data_read = True | |||
| def __call__(self, *args, **kwargs): | |||
| with self._setup(): | |||
| return self.__wrapped__(*args, **kwargs) | |||
| class CompiledTensorProxy(RawTensor): | |||
| """ | |||
| Duck-typed RawTensor | |||
| """ | |||
| def __init__(self, handle): | |||
| self.__handle = handle | |||
| self.__info = active_trace._tinfo[handle] | |||
| self.__shape = None | |||
| self.__data = None | |||
| self.__value = None | |||
| @property | |||
| def dtype(self): | |||
| return self.__info.varnode.dtype | |||
| @property | |||
| def device(self): | |||
| return self.__info.varnode.device | |||
| @property | |||
| def shape(self): | |||
| if self.__shape is None: | |||
| if self.__info.shape_read: | |||
| self.__shape = self.__info.shape_reader.get_value().shape | |||
| elif self.__info.data_read: | |||
| self.__shape = self._dev_tensor().shape | |||
| else: | |||
| raise TraceMismatchError("shape of this tensor is not read in trace") | |||
| return self.__shape | |||
| def numpy(self): | |||
| if self.__value is None: | |||
| if self.__info.value_read: | |||
| self.__value = self.__info.value_reader.get_value() | |||
| elif self.__info.data_read: | |||
| self.__value = self._dev_tensor().numpy() | |||
| else: | |||
| raise TraceMismatchError("value of this tensor is not read in trace") | |||
| return self.__value | |||
| def _dev_tensor(self): | |||
| if self.__data is None: | |||
| if not self.__info.data_read: | |||
| raise TraceMismatchError("raw data of this tensor is not read in trace") | |||
| self.__data = self.__info.data_reader.get_value() | |||
| return self.__data | |||
| def __del__(self): | |||
| if self.__info.shape_read and self.__shape is not None: | |||
| self.__info.shape_reader.drop_value() | |||
| if self.__info.value_read and self.__value is not None: | |||
| self.__info.value_reader.drop_value() | |||
| if self.__info.data_read and self.__data is not None: | |||
| self.__info.data_reader.drop_value() | |||
| class LazyEvalTensor(RawTensor): | |||
| def __init__(self, varnode): | |||
| self.__varnode = varnode | |||
| @property | |||
| def dtype(self): | |||
| return self.__varnode.dtype | |||
| @property | |||
| def device(self): | |||
| return self.__varnode.device | |||
| @property | |||
| def shape(self): | |||
| return self.__varnode.shape | |||
| def numpy(self): | |||
| raise RuntimeError("cannot read value during symbolic tracing") | |||
| def _dev_tensor(self): | |||
| raise RuntimeError("cannot access data during symbolic tracing") | |||
| class TraceMixin: | |||
| __subclass_cache = {} | |||
| def __inject(self, handle): | |||
| cache = __class__.__subclass_cache | |||
| cls = self.__class__ | |||
| subcls = cache.get(cls) | |||
| if subcls is None: | |||
| subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {}) | |||
| self.__class__ = subcls | |||
| self.__handle = handle | |||
| self.__cls = cls | |||
| return self | |||
| def __restore(self): | |||
| cls = self.__cls | |||
| del self.__handle | |||
| del self.__cls | |||
| self.__class__ = cls | |||
| return self | |||
| @property | |||
| def shape(self): | |||
| if not skip_tracing: | |||
| active_trace._require_shape(self.__handle) | |||
| return super().shape | |||
| def numpy(self): | |||
| if not skip_tracing: | |||
| active_trace._require_value(self.__handle) | |||
| return super().numpy() | |||
| def _dev_tensor(self): | |||
| if not skip_tracing: | |||
| active_trace._require_data(self.__handle) | |||
| return super()._dev_tensor() | |||
| class TracedRawTensor(TraceMixin, RawTensor): | |||
| pass | |||
| class TracedLazyTensor(TraceMixin, LazyEvalTensor): | |||
| pass | |||
| def assign_raw_tensor(lhs, rhs): | |||
| handle = rhs._handle | |||
| rhs.__dict__.clear() | |||
| lhs.__dict__.clear() | |||
| lhs.__class__ = RawTensor | |||
| lhs.__init__(handle) | |||
| # this hook turns RawTensor into LazyEvalTensor | |||
| @apply.register() | |||
| def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
| graph = active_trace._lazy_eval_graph | |||
| ivars = [ | |||
| getattr(x, "_LazyEvalTensor__varnode", None) | |||
| or graph.make_const(x._dev_tensor()) | |||
| for x in args | |||
| ] | |||
| ovars = apply(op, *ivars) | |||
| outputs = [LazyEvalTensor(v) for v in ovars] | |||
| active_trace._lazy_eval_tensors.update(outputs) | |||
| return outputs | |||
| apply.disable(apply_symbolic_mode) | |||
| @apply.register() | |||
| def apply_compiled_mode(op: OpDef, *args: RawTensor): | |||
| if skip_tracing: | |||
| args = [ | |||
| as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||
| for x in args | |||
| ] | |||
| return apply.super(op, *args) | |||
| return active_trace._apply_op(op, args) | |||
| apply.disable(apply_compiled_mode) | |||
| # this hook injects TraceMixin | |||
| @apply.register() | |||
| def apply_with_tracing(op: OpDef, *args: RawTensor): | |||
| outputs = apply.super(op, *args) | |||
| active_trace._record_op(op, args, outputs) | |||
| return outputs | |||
| apply.disable(apply_with_tracing) | |||
| # @apply.register() | |||
| # def _(op: Const, *args: RawTensor): | |||
| # return active_trace._apply_const(op, args) | |||
| class BrokenRawTensor(RawTensor): | |||
| def __getattribute__(self, _): | |||
| raise RuntimeError("broken due to misuse of tracing") | |||
| def __setattr__(self, *_): | |||
| raise RuntimeError("broken due to misuse of tracing") | |||
| @@ -23,10 +23,29 @@ namespace py = pybind11; | |||
| using namespace mgb; | |||
| using namespace imperative; | |||
| namespace { | |||
| template<typename XTensorND> | |||
| auto def_TensorND(py::object parent, const char* name) { | |||
| return py::class_<XTensorND>(parent, name) | |||
| .def_property_readonly("shape", py::overload_cast<>(&XTensorND::shape, py::const_)) | |||
| .def_property_readonly("dtype", py::overload_cast<>(&XTensorND::dtype, py::const_)) | |||
| .def_property_readonly("comp_node", py::overload_cast<>(&XTensorND::comp_node, py::const_)) | |||
| .def("copy_from", &XTensorND::template copy_from<DeviceTensorStorage>) | |||
| .def("copy_from", &XTensorND::template copy_from<HostTensorStorage>) | |||
| .def("copy_from_fixlayout", py::overload_cast<const DeviceTensorND&>( | |||
| &XTensorND::template copy_from_fixlayout<DeviceTensorStorage>)) | |||
| .def("copy_from_fixlayout", py::overload_cast<const HostTensorND&>( | |||
| &XTensorND::template copy_from_fixlayout<HostTensorStorage>)); | |||
| } | |||
| } // namespace | |||
| void init_common(py::module m) { | |||
| py::class_<CompNode>(m, "CompNode") | |||
| auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | |||
| .def(py::init()) | |||
| .def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | |||
| .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | |||
| .def("__str__", &CompNode::to_string_logical) | |||
| .def_static("_sync_all", &CompNode::sync_all) | |||
| .def(py::self == py::self) | |||
| @@ -40,19 +59,30 @@ void init_common(py::module m) { | |||
| return CompNode::load(cn); | |||
| })); | |||
| py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event") | |||
| .def("record", &CompNode::Event::record) | |||
| .def("wait", &CompNode::Event::host_wait); | |||
| py::implicitly_convertible<std::string, CompNode>(); | |||
| py::class_<DeviceTensorND>(m, "DeviceTensorND") | |||
| .def(py::init()) | |||
| .def_property_readonly("shape", py::overload_cast<>(&DeviceTensorND::shape, py::const_)) | |||
| .def_property_readonly("dtype", py::overload_cast<>(&DeviceTensorND::dtype, py::const_)) | |||
| .def_property_readonly("comp_node", py::overload_cast<>(&DeviceTensorND::comp_node, py::const_)) | |||
| def_TensorND<DeviceTensorND>(m, "DeviceTensorND") | |||
| .def("numpy", [](const DeviceTensorND& self) { | |||
| HostTensorND hv; | |||
| hv.copy_from(self).sync(); | |||
| return py::handle(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); | |||
| }); | |||
| def_TensorND<HostTensorND>(m, "HostTensorND") | |||
| .def(py::init([](py::array data, CompNode cn, DType dtype) { | |||
| if (!cn.valid()) { | |||
| throw py::type_error("device must not be None"); | |||
| } | |||
| return npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | |||
| })) | |||
| .def("numpy", [](const HostTensorND& self) { | |||
| return py::reinterpret_steal<py::object>(npy::ndarray_from_tensor(self, npy::ShareType::TRY_SHARE)); | |||
| }); | |||
| py::class_<cg::OperatorNodeConfig>(m, "OperatorNodeConfig") | |||
| .def(py::init()) | |||
| .def_property("name", | |||
| @@ -12,6 +12,7 @@ | |||
| #include "./graph_rt.h" | |||
| #include "megbrain/imperative/opr_utility.h" | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/imperative.h" | |||
| #include "./helper.h" | |||
| @@ -29,29 +30,44 @@ auto def_rendezvous(py::object m, const char* name) { | |||
| .def(py::init([](){return std::make_shared<Rendezvous<T>>();})) | |||
| .def("set", [](Rendezvous<T>& r, T v) {r.set(std::move(v));}) | |||
| .def("get", [](Rendezvous<T>& r) {return r.get();}, py::call_guard<py::gil_scoped_release>()) | |||
| .def("drop", &Rendezvous<T>::drop) | |||
| .def("reset", &Rendezvous<T>::reset); | |||
| } | |||
| using TensorAttr = LogicalTensorDesc; | |||
| using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>; | |||
| void init_graph_rt(py::module m) { | |||
| def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous"); | |||
| def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous"); | |||
| def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous"); | |||
| py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode") | |||
| .def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();}) | |||
| .def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();}) | |||
| .def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_)) | |||
| .def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();}) | |||
| .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();}); | |||
| .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();}) | |||
| .def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* { | |||
| auto&& mgr = v->owner_graph()->static_infer_manager(); | |||
| auto&& type = mgr.get_infer_type(v); | |||
| using InferType = cg::static_infer::InferType; | |||
| if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) { | |||
| return nullptr; | |||
| } | |||
| return mgr.infer_shape_fallible(v); | |||
| }); | |||
| py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode") | |||
| .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();}) | |||
| .def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_)) | |||
| .def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) { | |||
| return to_tuple(opr->input()); | |||
| }) | |||
| .def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) { | |||
| return to_tuple(opr->output()); | |||
| return to_tuple(opr->usable_output()); | |||
| }); | |||
| py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") | |||
| @@ -117,7 +133,7 @@ void init_graph_rt(py::module m) { | |||
| common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) { | |||
| cg::VarNodeArray vinputs(inputs.begin(), inputs.end()); | |||
| auto opr = OpDef::apply_on_var_node(def, vinputs); | |||
| auto outputs = opr->output(); | |||
| auto outputs = opr->usable_output(); | |||
| return to_tuple(outputs); | |||
| }, | |||
| py::arg(), py::arg(), py::arg("graph") = py::none()); | |||
| @@ -125,6 +141,7 @@ void init_graph_rt(py::module m) { | |||
| auto input_callback = [](auto callback, | |||
| const CompNode& comp_node, | |||
| const DType& dtype, | |||
| const TensorShape& shape, | |||
| const std::vector<cg::VarNode*>& inputs, | |||
| cg::ComputingGraph* graph) { | |||
| if (!graph) { | |||
| @@ -135,7 +152,7 @@ void init_graph_rt(py::module m) { | |||
| sinputs.emplace_back(i); | |||
| } | |||
| static_assert(!std::is_reference<decltype(callback)>::value); | |||
| auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, sinputs); | |||
| auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, shape, sinputs); | |||
| std::vector<VarNode*> outputs; | |||
| outputs.reserve(soutputs.size()); | |||
| for (auto i : soutputs) { | |||
| @@ -144,26 +161,40 @@ void init_graph_rt(py::module m) { | |||
| return outputs; | |||
| }; | |||
| m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) { | |||
| return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node(); | |||
| }); | |||
| m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { | |||
| if (!cn.valid()) { | |||
| throw py::type_error("device must not be None"); | |||
| } | |||
| auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | |||
| opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); | |||
| }); | |||
| m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback, | |||
| const CompNode& comp_node, | |||
| const DType& dtype, | |||
| const TensorShape& shape, | |||
| const std::vector<cg::VarNode*>& inputs, | |||
| cg::ComputingGraph* graph) { | |||
| return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, inputs, graph); | |||
| return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, shape, inputs, graph); | |||
| }, | |||
| py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | |||
| py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | |||
| m.def("input_callback", [input_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p, | |||
| const CompNode& comp_node, | |||
| const DType& dtype, | |||
| const TensorShape& shape, | |||
| const std::vector<cg::VarNode*>& inputs, | |||
| cg::ComputingGraph* graph) { | |||
| auto f = [p]() -> DeviceTensorND { | |||
| return p->get(); | |||
| }; | |||
| return input_callback(std::move(f), comp_node, dtype, inputs, graph); | |||
| return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph); | |||
| }, | |||
| py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | |||
| py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | |||
| auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false) { | |||
| SymbolVarArray sinputs; | |||
| @@ -193,6 +224,17 @@ void init_graph_rt(py::module m) { | |||
| return output_callback(std::move(f), std::move(inputs)); | |||
| }); | |||
| m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) { | |||
| auto f = [p](DeviceTensorND dv) { | |||
| HostNDWithEvent hv_with_event; | |||
| hv_with_event.first.copy_from(dv); | |||
| hv_with_event.second = dv.comp_node().create_event(); | |||
| hv_with_event.second->record(); | |||
| p->set(std::move(hv_with_event)); | |||
| }; | |||
| return output_callback(std::move(f), std::move(inputs), true); | |||
| }); | |||
| m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) { | |||
| auto f = [p](DeviceTensorND dv) { | |||
| p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()}); | |||
| @@ -39,6 +39,7 @@ template<typename R> | |||
| class Rendezvous { | |||
| std::mutex m_lock; | |||
| int m_read_ahead = 0; | |||
| bool m_drop_next = false; | |||
| std::promise<R> m_promise; | |||
| public: | |||
| Rendezvous() = default; | |||
| @@ -47,6 +48,7 @@ public: | |||
| Rendezvous& operator=(const Rendezvous& rhs) = delete; | |||
| Rendezvous& operator=(Rendezvous&& rhs) { | |||
| MGB_LOCK_GUARD(m_lock); | |||
| m_drop_next = rhs.m_drop_next; | |||
| m_read_ahead = rhs.m_read_ahead; | |||
| m_promise = std::move(rhs.m_promise); | |||
| return *this; | |||
| @@ -67,12 +69,28 @@ public: | |||
| return f.get(); | |||
| } | |||
| void drop() { | |||
| MGB_LOCK_GUARD(m_lock); | |||
| mgb_assert(m_read_ahead <= 0); | |||
| mgb_assert(m_read_ahead >= -1); | |||
| if (m_read_ahead == -1) { | |||
| m_promise = {}; | |||
| } else { | |||
| m_drop_next = true; | |||
| } | |||
| ++m_read_ahead; | |||
| } | |||
| template<typename T> | |||
| void set(T&& value) { | |||
| MGB_LOCK_GUARD(m_lock); | |||
| mgb_assert(m_read_ahead >= 0); | |||
| mgb_assert(m_read_ahead <= 1); | |||
| m_promise.set_value(std::forward<T>(value)); | |||
| if (m_drop_next) { | |||
| m_drop_next = false; | |||
| } else { | |||
| m_promise.set_value(std::forward<T>(value)); | |||
| } | |||
| if (m_read_ahead == 1) { | |||
| m_promise = {}; | |||
| } | |||
| @@ -83,6 +101,7 @@ public: | |||
| MGB_LOCK_GUARD(m_lock); | |||
| m_promise = {}; | |||
| m_read_ahead = 0; | |||
| m_drop_next = false; | |||
| } | |||
| }; | |||
| @@ -280,9 +280,12 @@ namespace detail { | |||
| public: | |||
| bool load(handle src, bool convert) { | |||
| auto obj = reinterpret_steal<object>(src); | |||
| if (!isinstance<tuple>(obj)) { | |||
| if (!convert && !isinstance<tuple>(obj)) { | |||
| return false; | |||
| } | |||
| if (obj.is_none()) { | |||
| return true; | |||
| } | |||
| value.ndim = len(obj); | |||
| mgb_assert(value.ndim <= mgb::TensorShape::MAX_NDIM); | |||
| size_t i = 0; | |||
| @@ -63,6 +63,7 @@ void init_imperative_rt(py::module m) { | |||
| return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); | |||
| } | |||
| }, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none()) | |||
| .def("put", py::overload_cast<const DeviceTensorND&>(&Interpreter::Channel::put)) | |||
| .def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) { | |||
| return self.del(handle); | |||
| }) | |||
| @@ -24,6 +24,12 @@ constexpr bool has_fastcall = true; | |||
| constexpr bool has_fastcall = false; | |||
| #endif | |||
| #ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
| constexpr bool has_vectorcall = true; | |||
| #else | |||
| constexpr bool has_vectorcall = false; | |||
| #endif | |||
| template<typename... Args> | |||
| struct invocable_with { | |||
| template<typename T> | |||
| @@ -55,6 +61,9 @@ private: | |||
| public: | |||
| PyObject_HEAD | |||
| std::aligned_storage_t<sizeof(T), alignof(T)> storage; | |||
| #ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
| PyObject* vectorcall_slot; | |||
| #endif | |||
| inline T* inst() { | |||
| return reinterpret_cast<T*>(&storage); | |||
| @@ -155,6 +164,51 @@ private: | |||
| // polyfills | |||
| struct tp_vectorcall { | |||
| static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall); | |||
| static constexpr bool haskw = [](){if constexpr (valid) | |||
| if constexpr (std::is_invocable_v<T::tp_vectorcall, T, PyObject*const*, size_t, PyObject*>) | |||
| return true; | |||
| return false;}(); | |||
| template<typename = void> | |||
| static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject *kwnames) { | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| if constexpr (haskw) { | |||
| CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames)); | |||
| } else { | |||
| if (kwnames && PyTuple_GET_SIZE(kwnames)) { | |||
| PyErr_SetString(PyExc_TypeError, "expect no keyword argument"); | |||
| return nullptr; | |||
| } | |||
| CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf)); | |||
| } | |||
| } | |||
| static constexpr Py_ssize_t offset = []() {if constexpr (valid) return offsetof(wrap_t, vectorcall_slot); | |||
| else return 0;}(); | |||
| }; | |||
| struct tp_call { | |||
| static constexpr bool provided = HAS_MEMBER(T, tp_call); | |||
| static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}( | |||
| [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {}); | |||
| static constexpr bool valid = provided || tp_vectorcall::valid; | |||
| template<typename = void> | |||
| static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| CVT_RET_PYOBJ(inst->tp_call(args, kwargs)); | |||
| } | |||
| static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call; | |||
| else if constexpr (provided) return impl<>; | |||
| #ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
| else if constexpr (valid) return PyVectorcall_Call; | |||
| #endif | |||
| else return nullptr;}(); | |||
| }; | |||
| struct tp_new { | |||
| static constexpr bool provided = HAS_MEMBER(T, tp_new); | |||
| static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>; | |||
| @@ -163,11 +217,14 @@ private: | |||
| template<typename = void> | |||
| static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { | |||
| auto* self = type->tp_alloc(type, 0); | |||
| auto* ptr = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| if constexpr (has_vectorcall && tp_vectorcall::valid) { | |||
| reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | |||
| } | |||
| if constexpr (varkw) { | |||
| new(ptr) T(args, kwargs); | |||
| new(inst) T(args, kwargs); | |||
| } else { | |||
| new(ptr) T(); | |||
| new(inst) T(); | |||
| } | |||
| return self; | |||
| } | |||
| @@ -190,22 +247,6 @@ private: | |||
| else return impl<>;}(); | |||
| }; | |||
| struct tp_call { | |||
| static constexpr bool valid = HAS_MEMBER(T, tp_call); | |||
| static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}( | |||
| [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {}); | |||
| template<typename = void> | |||
| static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| CVT_RET_PYOBJ(inst->tp_call(args, kwargs)); | |||
| } | |||
| static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call; | |||
| else if constexpr (valid) return impl<>; | |||
| else return nullptr;}(); | |||
| }; | |||
| public: | |||
| class TypeBuilder { | |||
| std::vector<PyMethodDef> m_methods; | |||
| @@ -228,9 +269,17 @@ public: | |||
| m_type.tp_name = T::tp_name; | |||
| } | |||
| m_type.tp_dealloc = tp_dealloc::value; | |||
| #ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
| m_type.tp_vectorcall_offset = tp_vectorcall::offset; | |||
| #endif | |||
| m_type.tp_call = tp_call::value; | |||
| m_type.tp_basicsize = sizeof(wrap_t); | |||
| m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| #ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
| if constexpr (tp_vectorcall::valid) { | |||
| m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL; | |||
| } | |||
| #endif | |||
| m_type.tp_new = tp_new::value; | |||
| } | |||
| @@ -0,0 +1,65 @@ | |||
| import numpy as np | |||
| from megengine.core.ops import builtin as ops | |||
| from megengine.core.tensor.core import apply | |||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
| from megengine.jit import exclude_from_trace, trace | |||
| def test_trace(): | |||
| for symbolic in [False, True]: | |||
| @trace(symbolic=symbolic) | |||
| def f(x): | |||
| op = ops.Elemwise(mode="negate") | |||
| (y,) = apply(op, x) | |||
| return y | |||
| x = as_raw_tensor([1]).numpy() | |||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
| for i in range(3): | |||
| np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
| def test_exclude_from_trace(): | |||
| for symbolic in [False, True]: | |||
| @trace(symbolic=symbolic) | |||
| def f(x): | |||
| neg = ops.Elemwise(mode="negate") | |||
| (x,) = apply(neg, x) | |||
| with exclude_from_trace(): | |||
| if i % 2: | |||
| (x,) = apply(neg, x) | |||
| (x,) = apply(neg, x) | |||
| return x | |||
| x = as_raw_tensor([1]).numpy() | |||
| for i in range(3): | |||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
| np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
| def test_print_in_trace(): | |||
| for symbolic in [False]: # cannot read value in symbolic mode | |||
| @trace(symbolic=symbolic) | |||
| def f(x): | |||
| nonlocal buf | |||
| neg = ops.Elemwise(mode="negate") | |||
| (x,) = apply(neg, x) | |||
| buf = x.numpy() | |||
| (x,) = apply(neg, x) | |||
| return x | |||
| buf = None | |||
| x = as_raw_tensor([1]).numpy() | |||
| for i in range(3): | |||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
| z = buf | |||
| buf = None | |||
| np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
| np.testing.assert_equal(z, buf) | |||
| @@ -37,6 +37,15 @@ void* ChannelImpl::put(const HostTensorND& value) { | |||
| return info; | |||
| } | |||
| void* ChannelImpl::put(const DeviceTensorND& data) { | |||
| auto info = alloc(); | |||
| info->desc.layout = data.layout(); | |||
| info->desc.comp_node = data.comp_node(); | |||
| info->ptr = Tensor::make(data); | |||
| m_valid_handle.insert(info); | |||
| return info; | |||
| } | |||
| void ChannelImpl::del(void* handle) { | |||
| mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle); | |||
| m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)}); | |||
| @@ -55,6 +55,7 @@ struct ChannelImpl : Interpreter::Channel { | |||
| ~ChannelImpl() override; | |||
| Handle put(const HostTensorND& value) override; | |||
| Handle put(const DeviceTensorND& value) override; | |||
| void del(Handle) override; | |||
| @@ -31,9 +31,10 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback); | |||
| InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback, | |||
| const VarNodeArray& inputs, | |||
| const TensorShape& output_shape, | |||
| const OperatorNodeConfig& config) | |||
| : Super(&graph, config, "input_callback", inputs), | |||
| m_callback(callback) { | |||
| m_output_shape(output_shape), m_callback(callback) { | |||
| for (VarNode* i : inputs) { | |||
| add_input({i}); | |||
| } | |||
| @@ -48,7 +49,8 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback, | |||
| SymbolVarArray InputCallback::make(cg::ComputingGraph& graph, | |||
| callback_t callback, CompNode comp_node, | |||
| DType dtype, const SymbolVarArray& inputs) { | |||
| DType dtype, const TensorShape& shape, | |||
| const SymbolVarArray& inputs) { | |||
| mgb_assert(comp_node.valid()); | |||
| mgb_assert(dtype.valid()); | |||
| OperatorNodeConfig config; | |||
| @@ -56,11 +58,22 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph, | |||
| config.output_dtype(dtype); | |||
| auto vinputs = to_var_node_array(inputs); | |||
| auto opr = graph.insert_opr( | |||
| std::make_unique<InputCallback>(graph, callback, vinputs, config)); | |||
| std::make_unique<InputCallback>(graph, callback, vinputs, shape, config)); | |||
| return to_symbol_var_array(opr->output()); | |||
| } | |||
| void InputCallback::init_output_static_infer_desc() {} | |||
| void InputCallback::init_output_static_infer_desc() { | |||
| if (m_output_shape.ndim) { | |||
| using namespace cg::static_infer; | |||
| auto &&mgr = owner_graph()->static_infer_manager(); | |||
| auto infer_shape = [this](TensorShape &dest, const InpVal &) { | |||
| dest = m_output_shape; | |||
| return true; | |||
| }; | |||
| mgr.register_shape_infer(output(0), | |||
| {SourceType::CONSTANT, {}, infer_shape}); | |||
| } | |||
| } | |||
| cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const { | |||
| NodeProp* prop = Super::do_make_node_prop(); | |||
| @@ -73,9 +86,23 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const { | |||
| void InputCallback::scn_do_execute() { | |||
| auto dev_tensor = m_callback(); | |||
| if (m_output_shape.ndim) { | |||
| mgb_assert(dev_tensor.shape().eq_shape(m_output_shape)); | |||
| } | |||
| output(0)->reset_dev_tensor_from_tensor(dev_tensor); | |||
| } | |||
| cg::OperatorNodeBase* InputCallback::shallow_copy( | |||
| const serialization::OprShallowCopyContext &ctx, | |||
| const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config) { | |||
| auto &&opr = opr_.cast_final_safe<InputCallback>(); | |||
| auto* graph = ctx.owner_graph(opr, inputs); | |||
| return graph->insert_opr(std::make_unique<InputCallback>(*graph, opr.m_callback, inputs, opr.m_output_shape, config)); | |||
| } | |||
| MGB_REG_OPR_SHALLOW_COPY(InputCallback, InputCallback::shallow_copy); | |||
| /* ================ OutputCallback ================== */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(OutputCallback); | |||
| @@ -122,6 +149,17 @@ void OutputCallback::scn_do_execute() { | |||
| m_param.callback(input(0)->dev_tensor()); | |||
| } | |||
| cg::OperatorNodeBase* OutputCallback::shallow_copy( | |||
| const serialization::OprShallowCopyContext &ctx, | |||
| const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config) { | |||
| auto &&opr = opr_.cast_final_safe<OutputCallback>(); | |||
| auto* graph = ctx.owner_graph(opr, inputs); | |||
| return graph->insert_opr(std::make_unique<OutputCallback>(opr.m_param, inputs, config)); | |||
| } | |||
| MGB_REG_OPR_SHALLOW_COPY(OutputCallback, OutputCallback::shallow_copy); | |||
| /* ================ NopCallback ================== */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(NopCallback); | |||
| @@ -22,6 +22,7 @@ struct Interpreter { | |||
| virtual ~Channel() = default; | |||
| virtual Handle put(const HostTensorND& value) = 0; | |||
| virtual Handle put(const DeviceTensorND& value) = 0; | |||
| virtual void del(Handle) = 0; | |||
| @@ -17,6 +17,7 @@ | |||
| #include "megbrain/opr/internal/param_tag_defs.h" | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megbrain/opr/param_defs.h" | |||
| #include "megbrain/serialization/sereg.h" | |||
| #include "megdnn/oprs/utils.h" | |||
| @@ -33,17 +34,24 @@ public: | |||
| InputCallback(cg::ComputingGraph& graph, | |||
| callback_t callback, | |||
| const VarNodeArray& inputs, | |||
| const TensorShape& output_shape, | |||
| const OperatorNodeConfig &config); | |||
| static SymbolVarArray make(cg::ComputingGraph& graph, | |||
| callback_t callback, | |||
| CompNode comp_node, | |||
| DType dtype, | |||
| const TensorShape& shape, | |||
| const SymbolVarArray& inputs = {}); | |||
| static cg::OperatorNodeBase* shallow_copy( | |||
| const serialization::OprShallowCopyContext &ctx, | |||
| const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config); | |||
| protected: | |||
| void scn_do_execute() override; | |||
| void init_output_static_infer_desc() override; | |||
| NodeProp* do_make_node_prop() const override; | |||
| private: | |||
| TensorShape m_output_shape; | |||
| callback_t m_callback; | |||
| }; | |||
| @@ -63,6 +71,10 @@ public: | |||
| SymbolVar input) { | |||
| return make(std::move(param), SymbolVarArray{input}); | |||
| } | |||
| static cg::OperatorNodeBase* shallow_copy( | |||
| const serialization::OprShallowCopyContext &ctx, | |||
| const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config); | |||
| protected: | |||
| void scn_do_execute() override; | |||
| void init_output_static_infer_desc() override; | |||