GitOrigin-RevId: 32dd49a23a
tags/v1.8.0
@@ -28,9 +28,6 @@ class AttachSpec:
    __slots__ = "tensor", "callbacks"
_global_priority = 0
class GradManager:
    r"""GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
    automatic differentiation (a.k.a. back propagation).
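For orientation, typical usage of ``GradManager`` looks roughly like the sketch below; this follows the pattern from the library's documentation, assuming a ``model`` exposing ``parameters()``, a hypothetical ``loss_fn``, and an optimizer whose ``step()`` returns itself:

.. code-block:: python

    gm = GradManager()
    gm.attach(model.parameters())
    for data, label in dataset:
        with gm:                        # record() on enter, release() on exit
            pred = model(data)
            loss = loss_fn(pred, label)
            gm.backward(loss)           # gradients land in each parameter's .grad
        optimizer.step().clear_grad()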
@@ -127,7 +124,6 @@ class GradManager:
        self._grad = None
        self._after_backward_callback = []
        self._gradients = {}
        self._priority = None
    def attached_tensors(self):
        r"""Return attached tensor list from :meth:`attach`."""
@@ -299,31 +295,25 @@ class GradManager:
                tensor.grad = grad
            else:
                tensor.grad += grad
            if tensor._isscalar() and tensor.grad is not None:
                tensor.grad._setscalar()
        finally:
            self.release()
            backwarding_grad_manager = cache
            set_option("record_computing_path", 1)
            pop_scope("backward")
    def record(self):
        r"""Start recording operations
        After this call, you will be able to call :meth:`backward`.
        """
        global _global_priority
        if self._recording:
            raise RuntimeError("already recording")
        grad = Grad()
        self._recording = True
        self._grad = grad
        grad.__enter__()
        for spec in self._attach_specs.values():
            self._do_record(spec)
        if self._priority is None:
            grad._priority = _global_priority
            _global_priority -= 1
        grad.__enter__()
    def _do_record(self, spec):
        tensor = spec.tensor()
@@ -331,6 +321,8 @@ class GradManager:
            return
        def callback(grad, callbacks=spec.callbacks):
            from ..functional import ones_like
            for cb in callbacks:
                grad = cb(tensor, grad)
            self._gradients[id(tensor)] = grad
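The callback chain above runs each incoming gradient through the callbacks registered at ``attach()`` time before it is stored. A minimal sketch, assuming ``attach()`` accepts a ``callbacks`` list as this code suggests; ``scale_grad`` is a hypothetical callback:

.. code-block:: python

    def scale_grad(tensor, grad):
        # hypothetical: halve each gradient before it is accumulated
        return grad * 0.5

    gm = GradManager()
    gm.attach(model.parameters(), callbacks=[scale_grad])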
@@ -343,14 +335,11 @@ class GradManager:
        After this call, you will not be able to call :meth:`backward`.
        """
        global _global_priority
        if self._grad is not None:
            self._grad.__exit__(None, None, None)
            self._grad = None
        self._recording = False
        self._gradients = dict()
        if self._priority is None:
            _global_priority += 1
    def __enter__(self):
        self.record()
@@ -382,15 +371,14 @@ class GradManagerGroup:
    __ror__ = merge_with
    def __enter__(self):
        global _global_priority
        _global_priority += 1
        Grad.stack.append([])
        Grad.begin_group()
        for gm in self._gms:
            gm._priority = _global_priority
            gm.record()
            assert gm._grad is not None
        Grad.end_group()
    def __exit__(self, exc_type, exc_val, exc_tb):
        global _global_priority
        _global_priority -= 1
        for gm in self._gms:
        for gm in reversed(self._gms):
            gm.release()
            gm._priority = None
            assert gm._grad is None
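``begin_group()``/``end_group()`` make several managers share one ``Grad.stack`` entry, so each ``backward`` can suppress the other members of its group. A sketch of the intended use, assuming ``merge_with`` (aliased to ``|`` above) yields a ``GradManagerGroup``:

.. code-block:: python

    gm0 = GradManager()
    gm0.attach(net0.parameters())
    gm1 = GradManager()
    gm1.attach(net1.parameters())
    with gm0 | gm1:            # both managers record inside one group
        loss0 = net0(x).sum()
        loss1 = net1(x).sum()
        gm0.backward(loss0)    # the other group members are suppressed meanwhile
        gm1.backward(loss1)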
@@ -6,17 +6,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref
import numpy as np
from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from .._imperative_rt import core2
_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()
@@ -36,6 +28,10 @@ class GradKey(core2.GradKey):
class Grad:
    stack = []
    grouping = False
    key2grad = weakref.WeakValueDictionary()
    def __init__(self, name=None):
        global _grad_count
        if name is None:
@@ -43,15 +39,9 @@ class Grad:
            _grad_count += 1
        self._refkeeper = []
        self._impl = GradKey(name)
        Grad.key2grad[self._impl] = self
        _grad_manager_dict[self._name] = self
    @property
    def _priority(self):
        return self._impl.priority
    @_priority.setter
    def _priority(self, priority):
        self._impl.priority = priority
        self._group = [weakref.ref(self)]
    @property
    def _name(self):
@@ -70,33 +60,80 @@ class Grad:
        if not isinstance(ys, Sequence):
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]
        group = [ref() for ref in self._group]
        for grad in group:
            if grad is self:
                continue
            grad.suppress()
        self._impl.backward(ys, dys)
        for grad in group:
            if grad is self:
                continue
            grad.resume()
        self._refkeeper = None
        return None
    def __enter__(self):
        ref = weakref.ref(self)
        self._impl.enter()
        if Grad.grouping:
            group = Grad.stack[-1]
            self._group = group
            group.append(ref)
        else:
            Grad.stack.append(self._group)
        return self
    def __exit__(self, _1, _2, _3):
        self._impl.exit()
        self._refkeeper = None
        del self._impl
class Function(ops.PyOpBase):
        del Grad.key2grad[self._impl]
        self._impl = None
        self._group.remove(weakref.ref(self))
        if len(self._group) == 0:
            Grad.stack.remove(self._group)
    @staticmethod
    def begin_group():
        assert not Grad.grouping
        Grad.grouping = True
    @staticmethod
    def end_group():
        group = Grad.stack[-1]
        assert len(group) > 0
        assert Grad.grouping
        Grad.grouping = False
    def suppress(self):
        if self._impl is not None:
            self._impl.suppress()
    def resume(self):
        if self._impl is not None:
            self._impl.resume()
class Function:
    r"""Defines a block of operations with customizable differentiation.
    The computation should be defined in the ``forward`` method, with gradient
    computation defined in the ``backward`` method.
    Each instance of ``Function`` should be used only once during forwarding.
    Examples:
        .. code-block::
            class Sigmoid(Function):
                def forward(self, x):
                    y = 1 / (1 + F.exp(-x))
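The docstring example is cut off by the diff context. A plausible completion of the documented ``Sigmoid`` pattern — the saved ``self.y`` and the ``backward`` rule are reconstructed from the usual sigmoid derivative, not taken verbatim from this hunk:

.. code-block:: python

    class Sigmoid(Function):
        def forward(self, x):
            y = 1 / (1 + F.exp(-x))
            self.y = y                 # save forward result for the backward pass
            return y

        def backward(self, dy):
            y = self.y
            return dy * y * (1 - y)    # d(sigmoid)/dx = y * (1 - y)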
@@ -115,7 +152,7 @@ class Function(ops.PyOpBase):
        Returns:
            a tuple of Tensor or a single Tensor.
        Note:
            * This method should return a tuple of Tensor or a single Tensor representing the output
              of the function.
@@ -128,7 +165,7 @@ class Function(ops.PyOpBase):
        Args:
            output_grads: gradients of outputs that are returned by :meth:`forward`.
        Note:
            * In case when some tensors of outputs are not related to loss function, the corresponding
              values in ``output_grads`` would be ``None``.
@@ -148,10 +185,40 @@ class Function(ops.PyOpBase):
        return self._default_rule(*args), self.backward
    def __call__(self, *args):
        ret = core2.apply(self, *args)
        for arg in args:
            if not isinstance(arg, core2.Tensor):
                raise TypeError(
                    "op Function expect type Tensor as inputs, got {}".format(type(arg))
                )
        grad_key = core2.get_grad_key(args)
        if grad_key is None:
            return self._default_rule(*args)
        grad = Grad.key2grad[grad_key]
        group = [ref() for ref in grad._group]
        for grad in group:
            grad.suppress()
        outputs, backward = self._grad_rule(*args)
        for grad in reversed(group):
            grad.resume()
        def normalized_backward(*output_grads):
            input_grads = backward(*output_grads)
            if isinstance(input_grads, core2.Tensor) or input_grads is None:
                input_grads = (input_grads,)
            return input_grads
        if self.__single_output:
            (ret,) = ret
        return ret
            outputs = (outputs,)
        for grad in reversed(group):
            if grad._impl is None:
                continue
            outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs)
        if self.__single_output:
            (outputs,) = outputs
        return outputs
    def __getstate__(self):
        return self.__dict__
@@ -26,7 +26,6 @@ from .utils import (
    convert_inputs,
    isscalar,
    make_shape_tuple,
    setscalar,
)
_ElwMod = builtin.Elemwise.Mode
@@ -34,14 +33,7 @@ _ElwMod = builtin.Elemwise.Mode
def _elwise_apply(args, mode):
    op = builtin.Elemwise(mode)
    _isscalar = True
    for i in args:
        if isscalar(i) == False:
            _isscalar = False
            break
    (result,) = apply(op, *args)
    if _isscalar:
        setscalar(result)
    return result
@@ -203,8 +195,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
    op = builtin.RemoveAxis(axis=axis)
    (result,) = apply(op, inp)
    if len(axis) == inp.ndim:
        setscalar(result)
    return result
@@ -221,6 +211,7 @@ def _reduce(mode):
            op = builtin.Reduce(mode=mode, axis=0)
            (result,) = apply(op, data)
            result = _remove_axis(result, 0)
        elif isinstance(axis, collections.abc.Iterable):
            axis = _normalize_axis(self.ndim, axis, reverse=True)
            for ai in axis:
@@ -239,8 +230,6 @@ def _reduce(mode):
        if self.dtype == np.bool_:
            if mode in ["min", "max"]:
                result = result.astype("bool")
        if axis is None or self.ndim == 1:
            setscalar(result)
        return result
    return f
@@ -457,7 +446,6 @@ class ArrayMethodMixin(abc.ABC):
                len(args) == 0
            ), "transpose for scalar does not accept additional args"
            ret = self.to(self.device)
            setscalar(ret)
            return ret
        if not args:
            args = range(self.ndim)[::-1]
@@ -111,7 +111,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
    if not isinstance(tuple_val, tuple):
        tuple_val = (tuple_val,)
    ndim_indexed = 0
    ndim_indexed_scalar = 0
    for i in tuple_val:
        if not i is Ellipsis:
            ndim_indexed += (
@@ -119,14 +118,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
                if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim")
                else 1
            )
            if isscalar(i):
                ndim_indexed_scalar += 1
    ret_scalar = False
    try:
        ret_scalar = ndim_indexed_scalar == inp.ndim
    except ValueError:
        # inp.ndim is unknown
        pass
    else:
        if ndim_indexed > inp.ndim:
            raise IndexError(
@@ -221,7 +212,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
            items.append(item)
    if new_axes:
        raise IndexError("newaxis is not allowed here")
    return inp, tensors, items, use_subtensor, ret_scalar
    return inp, tensors, items, use_subtensor
def try_condtake(tensor, index):
@@ -247,14 +238,12 @@ def getitem(tensor, index):
    try_result = try_condtake(tensor, index)
    if len(try_result) == 2:
        return try_result[0]
    tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
    if use_subtensor:
        op = builtin.Subtensor(items=items)
    else:
        op = builtin.IndexingMultiAxisVec(items=items)
    (result,) = apply(op, tensor, *tensors)
    if ret_scalar:
        result._setscalar()
    return result
@@ -266,7 +255,7 @@ def setitem(tensor, index, value):
        tensor = tensor.reshape(-1)
    if not isinstance(value, (Tensor, SymbolVar)):
        (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor)
    tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
    if use_subtensor:
        op = builtin.Subtensor(items=items)
    else:
@@ -17,6 +17,7 @@ import numpy as np
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions, SerializationFormat
from .._imperative_rt.core2 import apply
from .._wrap import as_device
from ..ops.builtin import OpDef
@@ -126,9 +127,8 @@ class Graph(_imperative_rt.ComputingGraph):
class VarNode:
    def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
    def __init__(self, node: _imperative_rt.VarNode):
        self._node = node
        self._isscalar = isscalar
        if hasattr(self.graph, "_var_cache"):
            self.graph._var_cache[node] = self
@@ -530,9 +530,6 @@ def _unwrap(x):
def apply_normal_varnode(op: OpDef, *args: VarNode):
    # for PyOp like RemoteSend/Recv
    if getattr(op, "op", None):
        op = op.op
    outputs = _imperative_rt.invoke_op(op, _unwrap(args))
    return _wrap(outputs)
@@ -51,10 +51,7 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype):
    dtype = np.dtype(dtype)
    if not is_dtype_equal(x.dtype, dtype):
        isscalar = x._isscalar()
        (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
        if isscalar:
            x._setscalar()
    return x
@@ -129,13 +126,6 @@ def isscalar(x):
    return np.isscalar(x)
def setscalar(x):
    if isinstance(x, (Tensor, SymbolVar)):
        x._setscalar()
    else:
        raise NotImplementedError("Unsupport type {}".format(type(x)))
def astensor1d(x, *reference, dtype=None, device=None):
    """Convert something to 1D tensor. Support following types
@@ -237,6 +227,7 @@ for name, mode in [
    ("**", "pow"),
    ("max", "max"),
    ("additive", "add"),
    ("exp", "EXP"),
]:
    _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
@@ -13,7 +13,7 @@ import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..core.tensor.utils import isscalar
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
@@ -72,15 +72,6 @@ def collective_comm(inp, mode, group, device):
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    if mode in (
        CollectiveComm.Mode.REDUCE_SUM,
        CollectiveComm.Mode.BROADCAST,
        CollectiveComm.Mode.ALL_REDUCE_SUM,
        CollectiveComm.Mode.ALL_REDUCE_MAX,
        CollectiveComm.Mode.ALL_REDUCE_MIN,
    ):
        if isscalar(inp):
            setscalar(result)
    return result
@@ -190,8 +181,7 @@ def reduce_sum(
    # Rank 0 # output: None
    # Rank 1 # output: Tensor([1])
    """
    op = _ReduceSum(group, device)
    (out,) = apply(op, inp)
    out = _ReduceSum(group, device)(inp)
    if group.rank == 0:
        return out
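A hedged sketch of how ``reduce_sum`` is typically driven, assuming the ``megengine.distributed.launcher`` decorator and ``get_rank`` helper; per the code above, only rank 0 receives the summed tensor and other ranks get ``None``:

.. code-block:: python

    import megengine.distributed as dist

    @dist.launcher(n_gpus=2)
    def worker():
        x = Tensor([dist.get_rank()])  # rank 0 holds [0], rank 1 holds [1]
        out = reduce_sum(x)            # rank 0: Tensor([1]); other ranks: None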
@@ -258,8 +248,7 @@ def broadcast(
    _bcast_tracer_state(group, inp)
    op = _Broadcast(group, device)
    (out,) = apply(op, inp)
    out = _Broadcast(group, device)(inp)
    return out
@@ -604,8 +593,7 @@ def gather(
            inp.shape
        )
    op = _Gather(group, device)
    (out,) = apply(op, inp)
    out = _Gather(group, device)(inp)
    if group.rank == 0:
        if axis == 0:
@@ -708,8 +696,7 @@ def scatter(
        + [_ for _ in range(axis + 1, inp.ndim + 1)]
    )
    inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    op = _Scatter(group, device)
    (out,) = apply(op, inp)
    out = _Scatter(group, device)(inp)
    return out
@@ -832,7 +819,7 @@ class _RemoteRecv(Function):
        self.op = op
    def forward(self, dummy):
        return apply(self.op, dummy)
        return apply(self.op, dummy)[0]
    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
@@ -871,7 +858,7 @@ def remote_send(inp: Tensor, dest_rank: int):
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    (out,) = apply(_RemoteSend(op), inp)
    out = _RemoteSend(op)(inp)
    _save_output_for_autodiff(inp, out)
@@ -912,11 +899,6 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)
    _isscalar = False
    if len(shape) == 0:
        shape = (1,)
        _isscalar = True
    op = RemoteRecv()
    op.key = group.key
    op.cn = device
@@ -926,7 +908,5 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor
    op.rank_from = src_rank
    op.backend = _backend()
    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:
        setscalar(ret)
    ret = _RemoteRecv(op)(inp)
    return ret
@@ -67,9 +67,6 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
    op.offsets = offsets
    op.shapes = [s or (1,) for s in shapes]
    outputs = apply(op, inp)
    for s, x in zip(shapes, outputs):
        if not s:
            x._setscalar()
    return outputs
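For reference, ``param_pack_split`` slices a flat buffer into tensors of the given shapes; a sketch following the example in the library's documentation (offsets come in begin/end pairs):

.. code-block:: python

    a = F.ones(10)
    b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
    # b is a[0:1] viewed as shape (1,); c is a[1:10] viewed as shape (3, 3)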
@@ -1,25 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import (
    set_allow_higher_order_directive as _set_allow_higher_order_directive,
)
__all__ = [
    "enable_higher_order_directive",
    "disable_higher_order_directive",
]
def enable_higher_order_directive():
    _set_allow_higher_order_directive(True)
def disable_higher_order_directive():
    _set_allow_higher_order_directive(False)
@@ -12,8 +12,5 @@ from ..core.ops.builtin import InplaceAdd
def _inplace_add_(dest, delta, alpha, beta):
    isscalar = dest._isscalar()
    dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
    if isscalar:
        dest._setscalar()
    return dest
@@ -19,7 +19,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const
from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
from ..jit import exclude_from_trace
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
@@ -1149,7 +1149,6 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
        inp1.ndim <= 1 and inp2.ndim <= 1
    ), "Input tensors for dot must be 1-dimensional or scalar"
    (result,) = apply(op, inp1, inp2)
    setscalar(result)
    return result
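With the ``setscalar`` call gone, ``dot`` simply returns what ``apply`` produces; a minimal sketch of the behavior this hunk preserves:

.. code-block:: python

    a = Tensor([1.0, 2.0, 3.0])
    b = Tensor([4.0, 5.0, 6.0])
    out = F.dot(a, b)   # Tensor(32.0): 1*4 + 2*5 + 3*6, no scalar flag needed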
@@ -1200,5 +1199,4 @@ def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
    for i in range(len(inps)):
        inps[i]._reset(oups[i])
    out._setscalar()
    return out
@@ -35,7 +35,6 @@ from ..core.tensor.utils import (
    cast_tensors,
    convert_single_value,
    make_shape_tuple,
    setscalar,
    subgraph,
)
from ..device import get_default_device
@@ -972,13 +972,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
        )
    axis = sorted(axis)
    assert axis, "axis could not be empty"
    if inp._isscalar():
        assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0])
        if len(axis) == 1:
            inp = copy(inp, device=None)
            inp._unsetscalar()
            return inp
        axis = axis[1:]
    op = builtin.AddAxis(axis=axis)
    (result,) = apply(op, inp)
    return result
@@ -1164,8 +1157,6 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
    if axis is None:
        inp = inp.reshape(-1)  # flatten
        axis = 0
    if inp._isscalar():
        inp._unsetscalar()
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    # assume inp.ndim is not changed during trace
    max_axis = len(shape) - 1
@@ -6,19 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import (
    set_cpp_apply_const_with_tracing,
    set_cpp_apply_with_tracing,
)
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import (
    apply_const_with_tracing,
    apply_with_tracing,
    exclude_from_trace,
    trace,
)
set_cpp_apply_with_tracing(apply_with_tracing)
set_cpp_apply_const_with_tracing(apply_const_with_tracing)
from .tracing import TraceError, exclude_from_trace, trace
@@ -111,6 +111,7 @@ class Module(metaclass=ABCMeta):
        # used for profiler and automatic naming
        self._name = None
        self._short_name = None
    @abstractmethod
    def forward(self, inputs):
@@ -137,7 +138,7 @@ class Module(metaclass=ABCMeta):
        return HookHandler(self._forward_hooks, hook)
    def __call__(self, *inputs, **kwargs):
        AutoNaming.push_scope(self.name if self.name is not None else self._name)
        AutoNaming.push_scope(self.name if self.name is not None else self._short_name)
        for hook in self._forward_pre_hooks.values():
            modified_inputs = hook(self, inputs)
            if modified_inputs is not None:
@@ -641,15 +642,43 @@ class Module(metaclass=ABCMeta):
        else:
            if modules is not None and name in modules:
                modules.remove(name)
        for k, v in _expand_structure(name, value):
            if not v._name:
                v._name = k
            elif v._name != k:
        def append_name(prefix, name):
            if prefix is None or prefix == "":
                return name
            return prefix + "." + name
        def set_name(parent, prefix, name, obj):
            if isinstance(obj, Tensor):
                assert obj.name is not None
                if obj.name != "":
                    name = obj.name
            full_name = append_name(prefix, name)
            if obj._short_name and obj._short_name != name:
                logger.warning(
                    "try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format(
                        type(v), type(self), k, v._name
                        obj._short_name, type(parent), name, obj._short_name
                    )
                )
                return
            if isinstance(obj, Tensor):
                obj._prefix = prefix
                obj._name = full_name
                obj._short_name = name
                obj._set_name(obj._name)
                return obj._name
            elif isinstance(obj, Module):
                obj._name = full_name
                obj._short_name = name
                for k, v in obj._flatten(recursive=False, with_key=True):
                    set_name(obj, full_name, k, v)
                return obj._name
            else:
                assert False
        for k, v in _expand_structure(name, value):
            prefix = self._name if self._name else self.name
            set_name(self, prefix, k, v)
        super().__setattr__(name, value)
    def __delattr__(self, name: str):
@@ -14,6 +14,7 @@ from numpy.random import MT19937
from .. import Tensor
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import sync as _sync
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._imperative_rt.ops import (
@@ -650,6 +651,10 @@ class RNG:
    def __del__(self):
        if self._handle != 0:
            # RNG op might execute after handle released due to async dispatch, so
            # we need sync before delete a handle to avoid memory leak or
            # use-after-free
            _sync()
            _delete_rng_handle(self._handle)
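The scenario the new ``_sync()`` guards against, as a hedged sketch (the ``RNG`` constructor arguments follow the documented API; the device string is illustrative):

.. code-block:: python

    from megengine.random import RNG

    rng = RNG(seed=100, device="xpu0")  # owns a non-zero RNG handle
    x = rng.uniform(size=(2, 2))        # op may still be in flight (async dispatch)
    del rng                             # __del__ now syncs before freeing the handle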
@@ -12,7 +12,7 @@ import numpy as np
from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._imperative_rt.core2 import apply, set_py_tensor_type
from .core._trace_option import use_symbolic_shape
from .core._wrap import as_device
from .core.ops.builtin import Copy, GetVarShape
@@ -20,7 +20,6 @@ from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .logger import get_logger
from .utils.deprecation import deprecated
from .utils.naming import AutoNaming
logger = get_logger(__name__)
@@ -40,6 +39,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
    grad = None
    dmap_callback = None
    _qparams = None
    _custom_name = ""
    _name = None
    _short_name = None
    _prefix = None
    def __new__(
        cls,
@@ -81,9 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
        device: str = None,
        is_const: bool = False,
        no_cache: bool = False,
        name: str = None,
        name: str = "",
    ):
        pass
        if name is None:
            name = ""
        self._custom_name = name
        self._name = name
        self._short_name = name
        self._set_name(self._name)
        self._prefix = None
    @property
    def shape(self) -> Union[tuple, "Tensor"]:
@@ -151,12 +160,13 @@ class Tensor(_Tensor, ArrayMethodMixin):
    @property
    def name(self):
        return self.c_name
        return self._custom_name
    @name.setter
    def name(self, name):
        self.c_name = name
        AutoNaming.record_var_name(self._mixin_handle, name)
        self._custom_name = name
        self._name = self._prefix + "." + name if self._prefix else name
        self._set_name(self._name)
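How the new naming fields interact, as a minimal sketch based on the setter above:

.. code-block:: python

    w = Tensor([1.0, 2.0], name="w")  # stored in _custom_name / _name / _short_name
    assert w.name == "w"              # the property now reads _custom_name
    w.name = "weight"                 # re-derives _name from _prefix and pushes it
                                      # to the C++ side via _set_name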
    @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
    def set_value(self, value):
@@ -224,6 +234,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
        self._qparams = qparams
set_py_tensor_type(Tensor)
tensor = Tensor
@@ -6,7 +6,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from ._passes import optimize
from .pytree import register_supported_type
@@ -14,14 +13,12 @@ from .tm_config import disable_default_checker, enable_expr_checker
from .traced_module import (
    TracedModule,
    _register_all_builtin_module,
    cpp_apply_module_trace,
    register_as_builtin,
    trace_module,
    wrap,
)
_register_all_builtin_module()
set_cpp_apply_module_trace(cpp_apply_module_trace)
__all__ = [
    "register_as_builtin",
@@ -13,7 +13,6 @@ import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import ROIAlign, ROIPooling
from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..tensor import Tensor
from .tm_config import _exclude_from_trace
@@ -70,8 +69,6 @@ class TracedModuleChecker:
            self.current_node2values()[node] = apply(
                Copy(comp_node=value.device), value
            )[0]
            if isscalar(value):
                setscalar(self.current_node2values()[node])
    def check_apply_special_cases(self, opdef, num_outputs):
        indexs = list(range(num_outputs))
@@ -20,6 +20,7 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    apply,
    is_tracing_module,
    set_module_trace_hook,
    set_module_tracing,
    unset_module_tracing,
)
@@ -605,8 +606,7 @@ class Apply(Expr):
    def apply_module_trace_hook(cls, opdef, *inputs):
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
            assert node is not None
        if isinstance(opdef, FakeQuant):
            inp_nodes = [NodeMixin.get(inputs[0])]
@@ -805,3 +805,12 @@ class Constant(Expr):
            if isinstance(v, _ModuleState):
                state[k] = v.to_module()
        self.__dict__.update(state)
def _module_trace_capture(value):
    node = Constant.make(value)
    NodeMixin.wrap_safe(value, node)
    return node
set_module_trace_hook(Apply.apply_module_trace_hook)
@@ -101,9 +101,7 @@ BUILTIN_TENSOR_WRAP_METHOD = [
    "requires_grad",
    "_reset",
    "_isscalar",
    "_setscalar",
    "_tuple_shape",
    "_unsetscalar",
]
@@ -43,7 +43,6 @@ from ..core._imperative_rt.core2 import (
)
from ..core._trace_option import set_symbolic_shape
from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..module import Module
from ..module import external as MExternal
from ..module.qat import QATModule
@@ -1295,12 +1294,9 @@ def _wrapped_function(orig_func):
            return orig_func(*args, **kwargs)
        if isinstance(args[1], RawTensor):
            node = NodeMixin.get(inputs[1])
            is_scalar = isscalar(inputs[1])
            inputs[1] = apply(
                Copy(comp_node=inputs[1].device), Tensor(inputs[1])
            )[0]
            if is_scalar:
                setscalar(inputs[1])
            # copy inputs[1] to avoid tensor and Tensor(tensor) sharing the same m_tensor,
            # which would give them the same _NodeMixin__node during tracing.
            NodeMixin.wrap_safe(inputs[1], node)
@@ -2468,8 +2464,8 @@ def trace_module(
    try:
        net_name = mod._name if mod._name else mod.__class__.__name__
        use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        set_active_module_tracer(module_tracer(_wrapped_function))
        set_module_tracing()
        for cls in [Expr, Node]:
            cls._set_next_id(0)
        with active_module_tracer().patcher:
@@ -2518,9 +2514,9 @@ def trace_module(
        return traced_mod
    finally:
        set_symbolic_shape(use_sym_shape)
        set_active_module_tracer(None)
        unset_module_tracing()
        for t in mod.tensors(recursive=True):
            NodeMixin.clear_node(t)
        for t in inputs:
            NodeMixin.clear_node(t)
        set_active_module_tracer(None)
@@ -137,6 +137,11 @@ class Profiler(ContextDecorator):
            get_logger().info("process {} generating {}".format(self._pid, format))
            self._dump_callback(path, format)
            get_logger().info("profiling results written to {}".format(path))
            if os.path.getsize(path) > 64 * 1024 * 1024:
                get_logger().warning(
                    "profiling results too large, maybe you are profiling multi iters,"
                    "consider attach profiler in each iter separately"
                )
            self._dump_callback = None
            _living_profilers.remove(self)
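The per-iteration attachment the new warning recommends, as a minimal sketch; ``train_step`` and the step index are hypothetical placeholders:

.. code-block:: python

    from megengine.utils.profiler import Profiler

    profiler = Profiler()
    for step, batch in enumerate(dataset):
        if step == 10:              # profile one representative iteration only
            with profiler:
                train_step(batch)
        else:
            train_step(batch)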
@@ -9,9 +9,8 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
| #include "./grad.h" | |||
| #include "megbrain/imperative/backward_graph_opt.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| @@ -19,465 +18,19 @@ | |||
| #include "range/v3/all.hpp" | |||
| #include "./transformation.h" | |||
| namespace py = pybind11; | |||
| namespace views = ranges::views; | |||
| namespace mgb::imperative::python { | |||
| using scoped_disable = ApplyContext::scoped_disable; | |||
| using Flags = Tensor::Flags; | |||
| namespace { | |||
| struct GradSlotWeakPtr { | |||
| std::weak_ptr<GradFn> grad_fn; | |||
| size_t idx; | |||
| }; | |||
| std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( | |||
| ApplyContext& ctx, const apply_result_t& outputs) { | |||
| // hash | |||
| using OptimizedBackwardGraphCache = OpMethResultCache< | |||
| std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>; | |||
| thread_local OptimizedBackwardGraphCache cache; | |||
| decltype(cache)::key_t cache_key{ctx.op}; | |||
| SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs; | |||
| SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras); | |||
| input_descs.resize(ctx.nargs); | |||
| input_requires_grad.resize(ctx.nargs); | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| input_descs[i].layout.dtype = ctx.args[i]->dtype(); | |||
| input_descs[i].comp_node = ctx.args[i]->comp_node(); | |||
| input_requires_grad[i] = python::input_requires_grad(ctx, i); | |||
| } | |||
| auto iter = cache.find(cache_key); | |||
| if (iter != cache.end()) { | |||
| return iter->second; | |||
| } | |||
| // slow path | |||
| SmallVector<bool> output_has_grad(outputs.size(), true); | |||
| std::shared_ptr<OptimizedBackwardGraphResult> ret; | |||
| auto bg = OpDef::make_backward_graph( | |||
| *ctx.op, input_descs, input_requires_grad, output_has_grad); | |||
| if (!bg.graph.empty()) { | |||
| ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | |||
| } | |||
| cache.emplace(cache_key, ret); | |||
| return ret; | |||
| std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map; | |||
| } | |||
| struct BackwardGraphWithClosure { | |||
| std::shared_ptr<OptimizedBackwardGraphResult> backward_graph; | |||
| SmallVector<std::shared_ptr<Tensor>> closure; | |||
| size_t output_mask_offset; | |||
| size_t grad_mask_offset; | |||
| BackwardGraphWithClosure( | |||
| std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_, | |||
| ApplyContext& ctx, const apply_result_t& outputs) | |||
| : backward_graph(backward_graph_), | |||
| output_mask_offset(ctx.nargs), | |||
| grad_mask_offset(ctx.nargs + outputs.size()) { | |||
| // save_for_backward[0:nargs]: | |||
| // whether input is kept for backward | |||
| // | |||
| // save_for_backward[nargs:nargs+outputs.size()]: | |||
| // whether output is kept for backward | |||
| // | |||
| // save_for_backward[-outputs.size():]: | |||
| // whether gradient of output can propagate to any input | |||
| // | |||
| // Example: | |||
| // perform c = a * b, with a.requires_grad == True and | |||
| // b.requires_grad == False, save_for_backward = [0, 1, 0, 1] | |||
| auto& save_for_backward = backward_graph->save_for_backward; | |||
| mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size()); | |||
| size_t count = std::count_if( | |||
| save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); | |||
| if (!backward_graph->precomp.empty()) { | |||
| auto&& irng = ranges::span(ctx.args, ctx.nargs); | |||
| auto&& orng = views::transform(outputs, [](auto&& i) { return i.get(); }); | |||
| auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); | |||
| closure.reserve(precomp.size() + count); | |||
| std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure)); | |||
| } else { | |||
| closure.reserve(count); | |||
| } | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| if (save_for_backward[i]) { | |||
| closure.push_back(ctx.args[i]->shared_from_this()); | |||
| } | |||
| } | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| if (save_for_backward[ctx.nargs + i]) { | |||
| closure.push_back(outputs[i]); | |||
| } | |||
| } | |||
| } | |||
| template <typename T, typename R> | |||
| void operator()(BackwardContext&, T&& grads, R&& receiver) { | |||
| Tensor* args[closure.size() + grads.size()]; | |||
| size_t nargs = 0; | |||
| for (auto&& t : closure) { | |||
| args[nargs++] = t.get(); | |||
| } | |||
| bool null_grad = false; | |||
| for (size_t i = 0; i < grads.size(); ++i) { | |||
| if (backward_graph->save_for_backward[grad_mask_offset + i]) { | |||
| if (grads[i]) { | |||
| if (null_grad) { | |||
| PyErr_SetString(PyExc_NotImplementedError, "report to devs"); | |||
| throw py::error_already_set(); | |||
| } | |||
| args[nargs++] = grads[i]; | |||
| } else { | |||
| null_grad = true; | |||
| } | |||
| } | |||
| } | |||
| if (null_grad) | |||
| return; | |||
| auto igrads = apply(backward_graph->backward, args, nargs); | |||
| auto&& it = igrads.begin(); | |||
| for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) { | |||
| if (p) { | |||
| receiver(i, std::move(*it)); | |||
| ++it; | |||
| } | |||
| } | |||
| } | |||
| bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } | |||
| bool output_requires_grad(size_t i) { | |||
| return backward_graph->save_for_backward[grad_mask_offset + i]; | |||
| } | |||
| bool output_captured(size_t i) { | |||
| return backward_graph->save_for_backward[output_mask_offset + i]; | |||
| } | |||
| }; | |||
| struct PythonBackward { | |||
| py::object pyfunc; | |||
| size_t input_size; | |||
| PythonBackward(py::object f, size_t nin) : pyfunc(f), input_size(nin) {} | |||
| template <typename T, typename R> | |||
| void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { | |||
| auto args = py::tuple(grads.size()); | |||
| for (size_t i = 0; i < grads.size(); ++i) { | |||
| auto&& g = grads[i]; | |||
| args[i] = g ? ctx.wrap_tensor(g) : py::none(); | |||
| } | |||
| auto input_grads = py::reinterpret_steal<py::object>( | |||
| PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr)); | |||
| if (!input_grads) | |||
| throw py::error_already_set(); | |||
| if (input_grads.is_none()) | |||
| return; | |||
| if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) { | |||
| if (input_size != 1) { | |||
| throw py::value_error( | |||
| "custom grad rule returned wrong number of grads"); | |||
| } | |||
| if (!ctx.pytype) { | |||
| ctx.pytype = Py_TYPE(input_grads.ptr()); | |||
| } | |||
| receiver(0, tw->m_tensor); | |||
| return; | |||
| } | |||
| if (py::len(input_grads) != input_size) { | |||
| throw py::value_error("custom grad rule returned wrong number of grads"); | |||
| } | |||
| for (auto [i, g] : views::enumerate(input_grads)) { | |||
| if (g.is_none()) | |||
| continue; | |||
| auto* tw = TensorWrapper::try_cast(g.ptr()); | |||
| if (!tw) { | |||
| throw py::type_error("custom grad rule returned non-tensor"); | |||
| } | |||
| if (!ctx.pytype) { | |||
| ctx.pytype = Py_TYPE(g.ptr()); | |||
| } | |||
| receiver(i, tw->m_tensor); | |||
| } | |||
| } | |||
| static constexpr bool input_has_grad(size_t) { return true; } | |||
| static constexpr bool output_requires_grad(size_t) { return true; } | |||
| static constexpr bool output_captured(size_t) { return true; } | |||
| }; | |||
| } // namespace | |||
| struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> { | |||
| using Base = intrusive_list::Node<GradProducerRecord>; | |||
| GradProducerRecord() = default; | |||
| GradProducerRecord(GradProducerRecord::head_t& head) | |||
| : Base(intrusive_list::after_t{}, head) {} | |||
| // GradProducerRecord(GradProducerRecord&&) = default; | |||
| // GradProducerRecord& operator=(GradProducerRecord&) = default; | |||
| // GradProducerRecord& operator=(GradProducerRecord&&) = default; | |||
| }; | |||
| struct GradSlot { | |||
| std::shared_ptr<Tensor> grad; | |||
| py::object callback; | |||
| GradProducerRecord::head_t producer_head; | |||
| }; | |||
| struct GradSlotProducerPtr : GradSlotPtr { | |||
| GradProducerRecord producer_record; | |||
| GradSlotProducerPtr() = default; | |||
| GradSlotProducerPtr(GradInfo& info) | |||
| : GradSlotPtr(info), producer_record(info->producer_head) {} | |||
| }; | |||
| struct GradFn : std::enable_shared_from_this<GradFn> { | |||
| static MemPool<GradFn> pool; | |||
| std::weak_ptr<GradKey> key; | |||
| // slots for receiving and accumulating grads | |||
| // same length as outputs (of forward op) | |||
| SmallVector<GradSlot> slots; | |||
| // where to send and accumulate grads | |||
| // same length as inputs (of forward op) | |||
| SmallVector<GradSlotProducerPtr> dsts; | |||
    // encapsulates actual function to compute gradient
| std::variant< | |||
| std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> | |||
| backward; | |||
| // a flag used during backward | |||
| bool in_ref_keeper = false; | |||
| static void deleter(GradFn* ptr) { pool.free(ptr); } | |||
| static std::shared_ptr<GradFn> make() { | |||
| return std::shared_ptr<GradFn>(pool.alloc(), &deleter); | |||
| } | |||
| void clear() { | |||
| key.reset(); | |||
| slots.clear(); | |||
| dsts.clear(); | |||
| backward.emplace<std::monostate>(); | |||
| } | |||
| }; | |||
| GradSlotPtr::operator bool() const { | |||
| return bool(grad_fn); | |||
| } | |||
| GradSlot* GradSlotPtr::operator->() { | |||
| return &grad_fn->slots[idx]; | |||
| } | |||
| namespace { | |||
| class GradFnHelper { | |||
| std::shared_ptr<GradFn> grad_fn; | |||
| GradFn* get() { | |||
| if (!grad_fn) { | |||
| grad_fn = std::make_shared<GradFn>(); | |||
| } | |||
| return grad_fn.get(); | |||
| } | |||
| friend apply_result_t imperative::python::apply_grad(ApplyContext&); | |||
| public: | |||
| template <typename T, typename... Args> | |||
| auto& emplace(Args&&... args) { | |||
| return get()->backward.emplace<T>(std::forward<Args>(args)...); | |||
| } | |||
| void reset() { grad_fn = nullptr; } | |||
| }; | |||
| apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { | |||
| // copy inputs first, or trace will make InputNodes for each usage | |||
| ApplyContext ctx_dup = ctx; | |||
| SmallVector<std::shared_ptr<Tensor>> inputs_copy; | |||
| SmallVector<Tensor*> inputs_copy_weak; | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| Tensor* input = ctx.args[i]; | |||
| inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]); | |||
| inputs_copy_weak.push_back(inputs_copy.back().get()); | |||
| inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict; | |||
| if (input->m_flags & Flags::GRAD) { | |||
| inputs_copy.back()->m_flags |= Flags::GRAD; | |||
| } | |||
| } | |||
| ctx_dup.args = inputs_copy_weak.data(); | |||
| auto outputs = apply(ctx_dup); | |||
| auto backward_graph = make_backward_graph(ctx_dup, outputs); | |||
| if (!backward_graph) { | |||
| return outputs; | |||
| } | |||
| ret_grad_fn.emplace<BackwardGraphWithClosure>( | |||
| std::move(backward_graph), ctx_dup, outputs); | |||
| return outputs; | |||
| } | |||
| apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { | |||
| auto* op = ctx.op->try_cast_final<GenericPyOp>(); | |||
| py::tuple pyin(ctx.nargs); | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); | |||
| } | |||
| auto grad_rule = py::getattr(op->obj, "_grad_rule"); | |||
| auto pyret = py::reinterpret_steal<py::object>( | |||
| PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr)); | |||
| if (!pyret) | |||
| throw py::error_already_set(); | |||
| auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret); | |||
| ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs); | |||
| if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) { | |||
| return {tw->m_tensor}; | |||
| } | |||
| apply_result_t ret; | |||
| ret.reserve(py::len(outputs)); | |||
| for (auto&& i : outputs) { | |||
| auto* tw = TensorWrapper::try_cast(i.ptr()); | |||
| mgb_assert(tw); | |||
| ret.push_back(tw->m_tensor); | |||
| } | |||
| return ret; | |||
| } | |||
| } // namespace | |||
| apply_result_t apply_grad(ApplyContext& ctx) { | |||
| std::unordered_set<std::shared_ptr<GradKey>> grad_keys; | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| auto* tensor = ctx.args[i]; | |||
| if (!tensor->m_grad_info_dict.empty()) { | |||
| size_t grad_cnt = 0; | |||
| for (auto&& grad_info : tensor->m_grad_info_dict) { | |||
| auto input_grad_key = grad_info.grad_fn->key.lock(); | |||
| if (input_grad_key && input_grad_key->active && | |||
| !input_grad_key->is_blocked()) { | |||
| grad_keys.insert(input_grad_key); | |||
| grad_cnt++; | |||
| } | |||
| } | |||
| if (!grad_cnt) { | |||
| tensor->m_flags &= ~Flags::GRAD; | |||
| } | |||
| } else { | |||
| tensor->m_flags &= ~Flags::GRAD; | |||
| } | |||
| } | |||
| ctx.flags &= ~Flags::GRAD; | |||
| if (grad_keys.empty()) { | |||
| return apply(ctx); | |||
| } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) { | |||
| PyErr_SetString( | |||
| PyExc_NotImplementedError, | |||
| "second order directive not enabled, please call " | |||
| "'megengine.experimental.enable_higher_order_directive'"); | |||
| throw pyext17::py_err_set(); | |||
| } | |||
| GradFnHelper grad_fn_holder; | |||
| auto outputs = [&]() { | |||
| auto _ = scoped_disable(Flags::GRAD); | |||
| if (ctx.op->same_type<GenericPyOp>()) { | |||
| return python_grad_rule(ctx, grad_fn_holder); | |||
| } | |||
| auto&& registry = grad_rule_registry(); | |||
| auto&& it = registry.find(ctx.op->dyn_typeinfo()); | |||
| if (it != registry.end()) { | |||
| auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx); | |||
| if (auto ret = it->second(ctx, maker)) { | |||
| maker.finalize(); | |||
| return *ret; | |||
| } | |||
| grad_fn_holder.reset(); | |||
| } | |||
| return backward_graph_grad_rule(ctx, grad_fn_holder); | |||
| }(); | |||
| if (!grad_fn_holder.grad_fn) { | |||
| return outputs; | |||
| } | |||
| for (auto&& grad_key : grad_keys) { | |||
| auto grad_fn = std::make_shared<GradFn>(); | |||
| grad_fn->backward = grad_fn_holder.grad_fn->backward; | |||
| grad_fn->key = grad_key; | |||
| grad_fn->slots.resize(outputs.size()); | |||
| grad_fn->dsts.reserve(ctx.nargs); | |||
| std::visit( | |||
| [&](auto& backward) { | |||
| using T = std::decay_t<decltype(backward)>; | |||
| if constexpr (std::is_same_v<T, std::monostate>) { | |||
| mgb_assert(0); | |||
| } else { | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| if (backward.input_has_grad(i) && | |||
| input_requires_grad(ctx, i) && | |||
| ctx.args[i]->m_grad_info_dict.count(grad_key.get())) { | |||
| auto& input_grad_info = | |||
| ctx.args[i]->m_grad_info_dict.at( | |||
| grad_key.get()); | |||
| grad_fn->dsts.emplace_back(input_grad_info); | |||
| // register as grad producer | |||
| grad_fn->dsts.back().producer_record.insert_after( | |||
| input_grad_info->producer_head); | |||
| } else { | |||
| grad_fn->dsts.emplace_back(); | |||
| } | |||
| } | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| if (backward.output_requires_grad(i)) { | |||
| if (backward.output_captured(i)) { | |||
| // avoid reference cycle [Tensor <-> GradFn] | |||
| static std::shared_ptr<OpDef> op = | |||
| std::make_shared<FastpathCopy>(); | |||
| outputs[i] = python::apply(op, outputs[i])[0]; | |||
| } | |||
| // populate grad info of output tensor | |||
| auto& grad_info = | |||
| outputs[i]->m_grad_info_dict[grad_key.get()]; | |||
| grad_info.grad_fn = grad_fn; | |||
| grad_info.idx = i; | |||
| grad_info.insert_after(grad_key->free_vars_head); | |||
| outputs[i]->m_flags |= Flags::GRAD; | |||
| } | |||
| } | |||
| } | |||
| }, | |||
| grad_fn->backward); | |||
| // record forward history | |||
| grad_key->tape.emplace_back(grad_fn); | |||
| } | |||
| return outputs; | |||
| } | |||
| PyObject* GradKeyWrapper::get_priority() { | |||
| return py::cast(m_key->priority).release().ptr(); | |||
| } | |||
| void GradKeyWrapper::set_priority(pybind11::handle priority) { | |||
| m_key->priority = py::cast<int>(priority); | |||
| GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared<GradKey>()) { | |||
| grad_key_map[m_key] = this; | |||
| } | |||
| void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | |||
| @@ -488,157 +41,59 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | |||
| if (!tw) { | |||
| throw py::type_error("argument 1 must be Tensor"); | |||
| } | |||
| auto* tensor = tw->m_tensor.get(); | |||
| py::object callback; | |||
| if (args[1] != Py_None) { | |||
| callback = py::reinterpret_borrow<py::object>(args[1]); | |||
| } | |||
| m_key->attach(tensor, std::move(callback)); | |||
| } | |||
//! GradKey is weakly referred to by tensor->m_grad_info.grad_fn->key after attach
| void GradKey::attach(Tensor* tensor, pybind11::object callback) { | |||
| if (!active) { | |||
| throw py::value_error("grad key finalized"); | |||
| } | |||
| if (tensor->m_grad_info_dict.count(this)) { | |||
| if (tensor->m_grad_info_dict.at(this)->callback) { | |||
| throw py::value_error("callback already set on this tensor"); | |||
| GenericFunction generic_callback = | |||
| [=](Span<ValueRef> inputs) -> std::vector<ValueRef> { | |||
| mgb_assert(inputs.size() == 1); | |||
| if (callback) { | |||
| callback(TensorWrapper::make(py_tensor_type, inputs[0])); | |||
| } | |||
| } else { | |||
| auto& grad_info = tensor->m_grad_info_dict[this]; | |||
| grad_info.idx = 0; | |||
| auto& grad_fn = grad_info.grad_fn; | |||
| grad_fn = std::make_shared<GradFn>(); | |||
| grad_fn->key = shared_from_this(); | |||
| grad_fn->slots.resize(1); | |||
| grad_info.insert_after(free_vars_head); | |||
| tensor->m_flags |= Flags::GRAD; | |||
| } | |||
| tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback); | |||
| } | |||
| template <typename T> | |||
| void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) { | |||
| if (!grad) { | |||
| grad = std::forward<T>(delta); | |||
| return; | |||
| } | |||
| static std::shared_ptr<OpDef> op = | |||
| std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD)); | |||
| grad = apply(op, grad, std::forward<T>(delta))[0]; | |||
| return {}; | |||
| }; | |||
| tw->m_tensor->reset(imperative::apply( | |||
| AttachGrad(m_key), tw->m_tensor->data(), | |||
| FunctionValue::make(generic_callback))[0]); | |||
| } | |||
| void GradKey::backward( | |||
| std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) { | |||
| if (!active) { | |||
| throw py::value_error("finalized"); | |||
| void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) { | |||
| std::vector<ValueRef> args; | |||
| mgb_assert(tensors.size() == grads.size()); | |||
| for (auto&& tensor : tensors) { | |||
| args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); | |||
| } | |||
| if (tensors.size() != grads.size()) { | |||
| throw py::value_error("tensor and grad size mismatch"); | |||
| for (auto&& grad : grads) { | |||
| args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data()); | |||
| } | |||
| // this GradKey is marked inactive here | |||
| active = false; | |||
| struct CleanupGuard { | |||
| GradKey* owner; | |||
| size_t priority_backup; | |||
| CleanupGuard(GradKey* this_) : owner(this_) { | |||
| priority_backup = sm_min_priority; | |||
| sm_min_priority = owner->priority + 1; | |||
| } | |||
| ~CleanupGuard() { | |||
| owner->cleanup(); | |||
| sm_min_priority = priority_backup; | |||
| } | |||
| } _cleanup_guard(this); | |||
| if (tape.empty()) | |||
| return; | |||
| BackwardContext bctx; | |||
| if (!grads.empty()) { | |||
| bctx.pytype = Py_TYPE(grads[0]->self().ptr()); | |||
| } | |||
| for (size_t i = 0; i < tensors.size(); ++i) { | |||
| if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) { | |||
| continue; | |||
| } | |||
| auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this); | |||
| grad_info->grad = grads[i]->m_tensor; | |||
| } | |||
| std::vector<std::shared_ptr<GradFn>> ref_keeper; | |||
| ref_keeper.reserve(tape.size()); | |||
| // back-propagation in reverse order | |||
| for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) { | |||
| auto&& grad_fn = tape[k].lock(); | |||
| if (!grad_fn) | |||
| continue; | |||
| auto grad_receiver = [&](size_t i, auto&& g) { | |||
| auto& dst = grad_fn->dsts[i]; | |||
| if (dst) { | |||
| accum_grad(dst->grad, std::forward<decltype(g)>(g)); | |||
| } | |||
| }; | |||
| std::visit( | |||
| [&](auto&& backward) { | |||
| using T = std::decay_t<decltype(backward)>; | |||
| if constexpr (std::is_same_v<T, std::monostate>) { | |||
| mgb_assert(0); | |||
| } else { | |||
| auto&& grads = views::transform( | |||
| grad_fn->slots, | |||
| [](auto&& slot) { return slot.grad.get(); }); | |||
| backward( | |||
| bctx, std::forward<decltype(grads)>(grads), | |||
| grad_receiver); | |||
| } | |||
| }, | |||
| grad_fn->backward); | |||
| for (auto&& dst : grad_fn->dsts) { | |||
| if (!dst.grad_fn) | |||
| continue; | |||
| if (!dst.grad_fn->in_ref_keeper) { | |||
| // after grad_fn is cleared, refcnt of subsequent grad_fn | |||
| // could drop to 0 | |||
| dst.grad_fn->in_ref_keeper = true; | |||
| ref_keeper.push_back(dst.grad_fn); | |||
| } | |||
| if (!dst.producer_record.next && dst->callback && dst->grad) { | |||
| // I'm the last grad producer, invoke callback | |||
| dst->callback(bctx.wrap_tensor(dst->grad)); | |||
| } | |||
| } | |||
| grad_fn->clear(); | |||
| } // finish tape loop | |||
| imperative::apply(GradBackward(self->m_key), {args.data(), args.size()}); | |||
| } | |||
| void GradKey::cleanup() { | |||
| active = false; | |||
| tape.clear(); | |||
| for (intrusive_list::Iterator it(free_vars_head); it;) { | |||
| it->grad_fn.reset(); | |||
| (it++)->unlink(); | |||
| pybind11::function GradKeyWrapper::get_backward_closure( | |||
| GradKeyWrapper* self, py::list tensors) { | |||
| std::vector<ValueRef> args; | |||
| for (auto&& tensor : tensors) { | |||
| args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); | |||
| } | |||
| } | |||
| void GradKeyWrapper::backward( | |||
| std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) { | |||
| m_key->backward(std::move(tensors), std::move(grads)); | |||
| auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0] | |||
| .as<FunctionValue>(); | |||
| auto py_function = [closure](std::vector<TensorWrapper*> tensors) { | |||
| std::vector<ValueRef> args; | |||
| for (auto* tw : tensors) { | |||
| args.push_back(tw->m_tensor->data()); | |||
| } | |||
| (*closure)(args); | |||
| }; | |||
| return pybind11::cpp_function(py_function); | |||
| } | |||
| PyObject* GradKeyWrapper::get_name() { | |||
| return py::cast(m_key->name).release().ptr(); | |||
| return py::cast(m_key->name()).release().ptr(); | |||
| } | |||
| void GradKeyWrapper::set_name(py::handle name) { | |||
| m_key->name = py::cast<std::string>(name); | |||
| m_key->name(py::cast<std::string>(name)); | |||
| } | |||
| PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { | |||
| @@ -651,60 +106,39 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { | |||
| PyErr_SetString(PyExc_TypeError, "expect Tensor"); | |||
| return nullptr; | |||
| } | |||
| if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) { | |||
| if (imperative::apply(IsAttachedTo(m_key), tw->m_tensor->data())[0] | |||
| .cast<BoolValue>()) { | |||
| Py_RETURN_TRUE; | |||
| } | |||
| Py_RETURN_FALSE; | |||
| } | |||
| int GradKey::sm_min_priority = std::numeric_limits<int>::min(); | |||
| GradKey::~GradKey() { | |||
| cleanup(); | |||
| void GradKeyWrapper::enter() { | |||
| m_transformation = std::make_shared<GradTransformation>(m_key); | |||
| TransformationManager::get_instance().register_at<TransformationManager::Grad>( | |||
| m_transformation); | |||
| } | |||
| std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() { | |||
| static std::unordered_map<Typeinfo*, GradRuleFn> registry; | |||
| return registry; | |||
| void GradKeyWrapper::exit() { | |||
| TransformationManager::get_instance().unregister<TransformationManager::Grad>( | |||
| m_transformation); | |||
| m_transformation.reset(); | |||
| } | |||
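enter() and exit() push and pop a GradTransformation on the process-wide TransformationManager, mirroring Python's __enter__/__exit__ protocol. The same discipline can be packaged as an RAII guard; a hedged sketch (register_at/unregister are the real calls above, everything else here is hypothetical):

#include <cassert>
#include <memory>
#include <vector>

struct Transformation {};

struct Manager {
    std::vector<std::shared_ptr<Transformation>> stack;
};

struct ScopedTransformation {  // hypothetical RAII wrapper
    Manager& mgr;
    std::shared_ptr<Transformation> t;
    explicit ScopedTransformation(Manager& m)
            : mgr(m), t(std::make_shared<Transformation>()) {
        mgr.stack.push_back(t);  // like register_at in enter()
    }
    ~ScopedTransformation() {
        assert(!mgr.stack.empty() && mgr.stack.back() == t);  // LIFO discipline
        mgr.stack.pop_back();  // like unregister in exit()
    }
};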
| void GradInfoCollection::_shrink() { | |||
| auto pred = [](GradInfo& info) { | |||
| return !(info.grad_fn) || info.grad_fn->key.expired(); | |||
| }; | |||
| auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred); | |||
| m_storage.erase(iter, m_storage.end()); | |||
| void GradKeyWrapper::suppress() { | |||
| m_transformation->suppress(); | |||
| } | |||
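_shrink above is the erase-remove idiom: compact the collection by dropping every GradInfo whose grad_fn is gone or whose owning GradKey has expired. The same idiom on a plain vector of weak_ptr:

#include <algorithm>
#include <memory>
#include <vector>

void shrink(std::vector<std::weak_ptr<int>>& v) {
    auto dead = [](const std::weak_ptr<int>& w) { return w.expired(); };
    v.erase(std::remove_if(v.begin(), v.end(), dead), v.end());
}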
| bool GradInfoCollection::contains(GradKey* key) { | |||
| _shrink(); | |||
| for (auto&& grad_info : m_storage) { | |||
| if (grad_info.grad_fn->key.lock().get() == key) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| void GradKeyWrapper::resume() { | |||
| m_transformation->resume(); | |||
| } | |||
| GradInfo& GradInfoCollection::operator[](GradKey* key) { | |||
| _shrink(); | |||
| for (auto&& grad_info : m_storage) { | |||
| if (grad_info.grad_fn->key.lock().get() == key) { | |||
| return grad_info; | |||
| } | |||
| } | |||
| m_storage.emplace_back(); | |||
| return m_storage.back(); | |||
| GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) { | |||
| return grad_key_map.at(key); | |||
| } | |||
| GradInfo& GradInfoCollection::at(GradKey* key) { | |||
| _shrink(); | |||
| for (auto&& grad_info : m_storage) { | |||
| if (grad_info.grad_fn->key.lock().get() == key) { | |||
| return grad_info; | |||
| } | |||
| } | |||
| mgb_assert(false); | |||
| GradKeyWrapper::~GradKeyWrapper() { | |||
| grad_key_map.erase(m_key); | |||
| } | |||
| } // namespace mgb::imperative::python | |||
| @@ -12,166 +12,40 @@ | |||
| #pragma once | |||
| #include "./tensor.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "megbrain/imperative/transformations/grad.h" | |||
| #include "megbrain/utils/small_vector.h" | |||
| #include <megbrain/utils/small_vector.h> | |||
| #include <memory> | |||
| #include <optional> | |||
| namespace mgb::imperative::python { | |||
| apply_result_t apply_grad(ApplyContext& ctx); | |||
| struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj { | |||
| std::string name; | |||
| bool active = true; | |||
| GradInfo::head_t free_vars_head; | |||
| std::vector<std::weak_ptr<GradFn>> tape; | |||
| int priority = 0; | |||
| ~GradKey(); | |||
| void attach(Tensor* tensor, pybind11::object callback); | |||
| void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); | |||
| void cleanup(); | |||
| bool is_blocked() const { return priority < sm_min_priority; } | |||
| inline static bool allow_higher_order_directive = false; | |||
| private: | |||
| static int sm_min_priority; | |||
| }; | |||
| struct GradKeyWrapper { | |||
| struct GradKeyWrapper : NonCopyableObj { | |||
| using wrap_t = pyext17::wrap<GradKeyWrapper>; | |||
| static constexpr auto tp_name = pybind11::detail::_("GradKey"); | |||
| std::shared_ptr<GradKey> m_key; | |||
| std::shared_ptr<GradTransformation> m_transformation; | |||
| inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {} | |||
| GradKeyWrapper(); | |||
| PyObject* get_name(); | |||
| void set_name(pybind11::handle name); | |||
| PyObject* get_priority(); | |||
| void set_priority(pybind11::handle priority); | |||
| void attach(PyObject* const* args, size_t nargs); | |||
| void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); | |||
| static void backward(GradKeyWrapper* self, pybind11::list, pybind11::list); | |||
| static pybind11::function get_backward_closure( | |||
| GradKeyWrapper* self, pybind11::list); | |||
| PyObject* is_attached_to(PyObject* const* args, size_t nargs); | |||
| void enter(); | |||
| void exit(); | |||
| void suppress(); | |||
| void resume(); | |||
| static GradKeyWrapper* get(std::shared_ptr<GradKey> key); | |||
| ~GradKeyWrapper(); | |||
| }; | |||
| struct BackwardContext { | |||
| PyTypeObject* pytype = nullptr; | |||
| auto wrap_tensor(std::shared_ptr<Tensor> t) { | |||
| if (pytype) { | |||
| return TensorWrapper::make(pytype, std::move(t)); | |||
| } | |||
| return TensorWrapper::make(std::move(t)); | |||
| } | |||
| auto wrap_tensor(Tensor* t) { return wrap_tensor(t->shared_from_this()); } | |||
| }; | |||
| struct CustomBackward { | |||
| using BackwardFn = | |||
| std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>; | |||
| BackwardFn m_backward; | |||
| SmallVector<bool, 8> m_input_has_grad; | |||
| struct OutputAttr { | |||
| bool requires_grad = true, captured = true; | |||
| }; | |||
| SmallVector<OutputAttr> m_output_attrs; | |||
| public: | |||
| template <typename T, typename R> | |||
| void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { | |||
| size_t nargs = grads.size(); | |||
| Tensor* args[nargs]; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| args[i] = grads[i]; | |||
| } | |||
| auto ret = m_backward(ctx, args, nargs); | |||
| for (size_t i = 0; i < ret.size(); ++i) { | |||
| if (auto&& t = ret[i]) { | |||
| receiver(i, std::move(t)); | |||
| } | |||
| } | |||
| } | |||
| bool input_has_grad(size_t i) { return m_input_has_grad[i]; } | |||
| bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } | |||
| bool output_captured(size_t i) { return m_output_attrs[i].captured; } | |||
| class Maker { | |||
| bool output_size_set = false, input_has_grad_initialized = false; | |||
| CustomBackward& target; | |||
| ApplyContext& ctx; | |||
| void init_input_has_grad() { | |||
| if (!input_has_grad_initialized) { | |||
| input_has_grad_initialized = true; | |||
| target.m_input_has_grad.resize(ctx.nargs, true); | |||
| } | |||
| } | |||
| public: | |||
| Maker(CustomBackward& target_, ApplyContext& ctx_) | |||
| : target(target_), ctx(ctx_) {} | |||
| template <typename F> | |||
| Maker& backward(F&& f) { | |||
| mgb_assert(!target.m_backward); | |||
| target.m_backward = std::forward<F>(f); | |||
| return *this; | |||
| } | |||
| // mandatory | |||
| Maker& output_size(size_t sz) { | |||
| mgb_assert(!output_size_set); | |||
| output_size_set = true; | |||
| target.m_output_attrs.resize(sz); | |||
| return *this; | |||
| } | |||
| // optional, defaults to all true | |||
| Maker& input_has_grad(size_t i, bool v) { | |||
| init_input_has_grad(); | |||
| target.m_input_has_grad.at(i) = v; | |||
| return *this; | |||
| } | |||
| // optional, defaults to all true | |||
| Maker& output_requires_grad(size_t i, bool v) { | |||
| target.m_output_attrs.at(i).requires_grad = v; | |||
| return *this; | |||
| } | |||
| // optional, defaults to all true | |||
| Maker& output_captured(size_t i, bool v) { | |||
| target.m_output_attrs.at(i).captured = v; | |||
| return *this; | |||
| } | |||
| void finalize() { | |||
| mgb_assert(output_size_set); | |||
| init_input_has_grad(); | |||
| } | |||
| }; | |||
| Maker maker(ApplyContext& ctx) { return {*this, ctx}; } | |||
| }; | |||
| using GradRuleFn = std::function<std::optional<apply_result_t>( | |||
| ApplyContext&, CustomBackward::Maker&)>; | |||
| std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry(); | |||
| inline bool input_requires_grad(const ApplyContext& ctx, size_t i) { | |||
| return !ctx.args[i]->m_grad_info_dict.empty(); | |||
| } | |||
| struct GradRuleFallback : std::exception {}; | |||
| template <typename T> | |||
| bool register_grad_rule(Typeinfo* typeinfo, T&& rule) { | |||
| return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second; | |||
| } | |||
| } // namespace mgb::imperative::python | |||
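grad_rule_registry and register_grad_rule above implement a classic type-keyed function registry: a function-local static map (initialized on first use) plus a registration helper whose bool result reports whether the key was new. A self-contained miniature using std::type_index in place of megbrain's Typeinfo*:

#include <functional>
#include <typeindex>
#include <unordered_map>
#include <utility>

using Rule = std::function<int(int)>;

std::unordered_map<std::type_index, Rule>& rule_registry() {
    static std::unordered_map<std::type_index, Rule> registry;  // first-use init
    return registry;
}

template <typename Op, typename F>
bool register_rule(F&& rule) {
    // emplace does nothing if the key exists; .second reports insertion
    return rule_registry().emplace(typeid(Op), std::forward<F>(rule)).second;
}

struct AddOp {};
static bool registered = register_rule<AddOp>([](int g) { return g; });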
| namespace pybind11::detail { | |||
| @@ -1,43 +0,0 @@ | |||
| /** | |||
| * \file imperative/python/src/grad_info.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include <memory> | |||
| #include "./intrusive_list.h" | |||
| namespace mgb::imperative::python { | |||
| struct GradKey; | |||
| struct GradFn; | |||
| struct GradSlot; | |||
| struct GradSlotPtr { | |||
| std::shared_ptr<GradFn> grad_fn; | |||
| size_t idx; | |||
| operator bool() const; | |||
| GradSlot* operator->(); | |||
| }; | |||
| struct GradInfo : GradSlotPtr, | |||
| intrusive_list::Node<GradInfo, intrusive_list::before_t> { | |||
| GradInfo() = default; | |||
| GradInfo(GradInfo&) = default; | |||
| GradInfo(GradInfo&&) = default; | |||
| GradInfo& operator=(GradInfo&) = default; | |||
| GradInfo& operator=(GradInfo&&) = default; | |||
| GradInfo(const GradInfo& rhs) : GradInfo(const_cast<GradInfo&>(rhs)) {} | |||
| GradInfo& operator=(const GradInfo& rhs) { | |||
| return *this = const_cast<GradInfo&>(rhs); | |||
| } | |||
| }; | |||
| } // namespace mgb::imperative::python | |||
| @@ -11,261 +11,334 @@ | |||
| #include "./grad.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/transformations/grad.h" | |||
| namespace mgb::imperative::python { | |||
| class CustomGradMaker { | |||
| bool output_size_set = false, input_has_grad_initialized = false; | |||
| CustomBackward& target; | |||
| size_t nr_inputs; | |||
| void init_input_has_grad() { | |||
| if (!input_has_grad_initialized) { | |||
| input_has_grad_initialized = true; | |||
| target.m_input_has_grad.resize(nr_inputs, true); | |||
| } | |||
| } | |||
| public: | |||
| CustomGradMaker(CustomBackward& target, size_t nr_inputs) | |||
| : target(target), nr_inputs(nr_inputs) {} | |||
| CustomGradMaker& backward(CustomBackward::BackwardFn f) { | |||
| mgb_assert(!target.m_backward); | |||
| target.m_backward = f; | |||
| return *this; | |||
| } | |||
| // mandatory | |||
| CustomGradMaker& output_size(size_t sz) { | |||
| mgb_assert(!output_size_set); | |||
| output_size_set = true; | |||
| target.m_output_attrs.resize(sz); | |||
| return *this; | |||
| } | |||
| // optional, defaults to all true | |||
| CustomGradMaker& input_has_grad(size_t i, bool v) { | |||
| init_input_has_grad(); | |||
| target.m_input_has_grad.at(i) = v; | |||
| return *this; | |||
| } | |||
| // optional, defaults to all true | |||
| CustomGradMaker& output_requires_grad(size_t i, bool v) { | |||
| target.m_output_attrs.at(i).requires_grad = v; | |||
| return *this; | |||
| } | |||
| // optional, defaults to all true | |||
| CustomGradMaker& output_captured(size_t i, bool v) { | |||
| target.m_output_attrs.at(i).captured = v; | |||
| return *this; | |||
| } | |||
| void finalize() { | |||
| mgb_assert(output_size_set); | |||
| init_input_has_grad(); | |||
| } | |||
| }; | |||
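CustomGradMaker is a fluent builder: each setter returns *this so the rules below can chain calls (output_size(1).output_captured(0, false), then backward(...), then finalize()), and finalize() enforces the one mandatory field while defaulting the optional ones. A self-contained miniature of the pattern, all names hypothetical:

#include <cassert>
#include <functional>
#include <vector>

struct BackwardSpec {
    std::function<void()> fn;
    std::vector<bool> input_has_grad;
};

class SpecMaker {
    BackwardSpec& target;
    size_t nr_inputs;
    bool output_size_set = false;

public:
    SpecMaker(BackwardSpec& t, size_t n) : target(t), nr_inputs(n) {}
    SpecMaker& backward(std::function<void()> f) {
        target.fn = std::move(f);
        return *this;  // returning *this is what makes the calls chain
    }
    SpecMaker& output_size(size_t) {
        output_size_set = true;
        return *this;
    }
    void finalize() {
        assert(output_size_set);  // mandatory, as in CustomGradMaker::finalize
        target.input_has_grad.resize(nr_inputs, true);  // optionals default true
    }
};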
| namespace { | |||
| std::shared_ptr<Tensor> get_shape(Tensor* x) { | |||
| ValueRef get_shape(ValueRef x) { | |||
| static auto op = GetVarShape::make(); | |||
| return python::apply(op, x)[0]; | |||
| return imperative::apply(*op, x)[0]; | |||
| } | |||
| std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) { | |||
| ValueRef reduce_to(ValueRef x, ValueRef s) { | |||
| static auto op = Reduce::make(); | |||
| return python::apply(op, x, s)[0]; | |||
| return imperative::apply(*op, x, s)[0]; | |||
| } | |||
| std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) { | |||
| ValueRef reshape_to(ValueRef x, ValueRef s) { | |||
| static auto op = Reshape::make(); | |||
| return python::apply(op, x, s)[0]; | |||
| return imperative::apply(*op, x, s)[0]; | |||
| } | |||
| std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) { | |||
| ValueRef broadcast_to(ValueRef x, ValueRef s) { | |||
| static auto op = Broadcast::make(); | |||
| return python::apply(op, x, s)[0]; | |||
| return imperative::apply(*op, x, s)[0]; | |||
| } | |||
| std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) { | |||
| HostTensorND scalar{cn, {{1}, dtype}}; | |||
| std::memset(scalar.raw_ptr(), 0, dtype.size()); | |||
| interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false); | |||
| auto&& t = std::make_shared<Tensor>(handle); | |||
| auto res = broadcast_to(t.get(), shape); | |||
| ValueRef make_empty_tensor( | |||
| CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) { | |||
| HostTensorStorage storage(*device); | |||
| storage.ensure_size(dtype->size()); | |||
| std::memset(storage.ptr(), 0, dtype->size()); | |||
| auto t = imperative::apply( | |||
| CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()), | |||
| HostStorage::make(storage))[0]; | |||
| auto res = broadcast_to(t, shape); | |||
| return res; | |||
| } | |||
| std::optional<apply_result_t> elemwise_grad_rule( | |||
| ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
| auto& op = ctx.op->cast_final_safe<Elemwise>(); | |||
| if (op.mode == Elemwise::Mode::ADD) { | |||
| mgb_assert(ctx.nargs == 2); | |||
| std::array<std::shared_ptr<Tensor>, 2> input_shapes; | |||
| std::optional<std::vector<ValueRef>> elemwise_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto& elemwise = op.cast_final_safe<Elemwise>(); | |||
| if (elemwise.mode != Elemwise::Mode::ADD) { | |||
| return {}; | |||
| } | |||
| mgb_assert(inputs.size() == 2); | |||
| std::array<ValueRef, 2> input_shapes; | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| if (inputs_require_grad[i]) { | |||
| input_shapes[i] = get_shape(inputs[i]); | |||
| } | |||
| } | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(2); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| if (input_requires_grad(ctx, i)) { | |||
| input_shapes[i] = get_shape(ctx.args[i]); | |||
| if (shapes[i]) { | |||
| ret[i] = reduce_to(grad, shapes[i]); | |||
| } | |||
| } | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([shapes = std::move(input_shapes)]( | |||
| BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
| mgb_assert(ngrads == 1); | |||
| Tensor* grad = grads[0]; | |||
| apply_result_t ret(2); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| if (shapes[i]) { | |||
| ret[i] = reduce_to(grad, shapes[i].get()); | |||
| } | |||
| } | |||
| return ret; | |||
| }); | |||
| return apply(ctx); | |||
| } | |||
| return {}; | |||
| return ret; | |||
| }); | |||
| maker.finalize(); | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
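The rule above handles broadcasting in ADD: the forward op may have broadcast an input up to the output shape, so the backward pass must reduce the incoming gradient back to each input's recorded shape via reduce_to (the assumption here is that Reduce::make() defaults to a sum-to-shape reduction). Duplicating a value in the forward pass means summing its gradients in the backward pass:

#include <cstdio>

int main() {
    // x of shape [1,3] was broadcast to [2,3] before the add;
    // dL/dx sums the output gradient over the broadcast axis.
    double gz[2][3] = {{1, 2, 3}, {4, 5, 6}};  // dL/dz
    double gx[3] = {};                         // dL/dx
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 3; ++j)
            gx[j] += gz[i][j];
    std::printf("%g %g %g\n", gx[0], gx[1], gx[2]);  // 5 7 9
}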
| std::optional<apply_result_t> reshape_grad_rule( | |||
| ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
| mgb_assert(ctx.nargs == 2); | |||
| std::array<std::shared_ptr<Tensor>, 2> input_shapes; | |||
| std::optional<std::vector<ValueRef>> reshape_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| mgb_assert(inputs.size() == 2); | |||
| std::array<ValueRef, 2> input_shapes; | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| if (input_requires_grad(ctx, i)) { | |||
| input_shapes[i] = get_shape(ctx.args[i]); | |||
| if (inputs_require_grad[i]) { | |||
| input_shapes[i] = get_shape(inputs[i]); | |||
| } | |||
| } | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([shapes = std::move(input_shapes)]( | |||
| BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
| mgb_assert(ngrads == 1); | |||
| Tensor* grad = grads[0]; | |||
| apply_result_t ret(2); | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(2); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| if (shapes[i]) { | |||
| ret[i] = reshape_to(grad, shapes[i].get()); | |||
| ret[i] = reshape_to(grad, shapes[i]); | |||
| } | |||
| } | |||
| return ret; | |||
| }); | |||
| return apply(ctx); | |||
| maker.finalize(); | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<apply_result_t> subtensor_grad_rule( | |||
| ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
| auto&& op = ctx.op->cast_final_safe<Subtensor>(); | |||
| auto&& grad_op = SetSubtensor::make(op.items); | |||
| SmallVector<std::shared_ptr<Tensor>> inputs; | |||
| if (input_requires_grad(ctx, 0)) { | |||
| inputs.push_back(get_shape(ctx.args[0])); | |||
| for (size_t i = 1; i < ctx.nargs; ++i) { | |||
| inputs.push_back(ctx.args[i]->copy()); | |||
| std::optional<std::vector<ValueRef>> subtensor_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& subtensor = op.cast_final_safe<Subtensor>(); | |||
| auto&& grad_op = SetSubtensor::make(subtensor.items); | |||
| SmallVector<ValueRef> inputs2; | |||
| if (inputs_require_grad[0]) { | |||
| inputs2.push_back(get_shape(inputs[0])); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| inputs2.push_back(inputs[i]); | |||
| } | |||
| } | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)]( | |||
| BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
| mgb_assert(ngrads == 1); | |||
| Tensor* grad = grads[0]; | |||
| apply_result_t ret(1); | |||
| maker.backward([inputs = std::move(inputs2), | |||
| grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| if (grad && inputs[0]) { | |||
| SmallVector<Tensor*> args_(inputs.size() + 1); | |||
| auto&& zeros = make_empty_tensor( | |||
| grad->comp_node(), inputs[0].get(), grad->dtype()); | |||
| args_[0] = zeros.get(); | |||
| SmallVector<ValueRef> args_(inputs.size() + 1); | |||
| auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
| args_[0] = zeros; | |||
| args_[1] = grad; | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| args_[i + 1] = inputs[i].get(); | |||
| args_[i + 1] = inputs[i]; | |||
| } | |||
| ret[0] = python::apply(grad_op_, args_)[0]; | |||
| ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0]; | |||
| } | |||
| return ret; | |||
| }); | |||
| return apply(ctx); | |||
| maker.finalize(); | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<apply_result_t> indexingMultiAxisVec_grad_rule( | |||
| ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
| auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>(); | |||
| auto&& grad_op = IndexingSetMultiAxisVec::make(op.items); | |||
| SmallVector<std::shared_ptr<Tensor>> inputs; | |||
| if (input_requires_grad(ctx, 0)) { | |||
| inputs.push_back(get_shape(ctx.args[0])); | |||
| for (size_t i = 1; i < ctx.nargs; ++i) { | |||
| inputs.push_back(ctx.args[i]->copy()); | |||
| std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>(); | |||
| auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items); | |||
| SmallVector<ValueRef> inputs2; | |||
| if (inputs_require_grad[0]) { | |||
| inputs2.push_back(get_shape(inputs[0])); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| inputs2.push_back(inputs[i]); | |||
| } | |||
| } | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)]( | |||
| BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
| mgb_assert(ngrads == 1); | |||
| Tensor* grad = grads[0]; | |||
| apply_result_t ret(1); | |||
| maker.backward([inputs = std::move(inputs2), | |||
| grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| if (grad && inputs[0]) { | |||
| SmallVector<Tensor*> args_(inputs.size() + 1); | |||
| auto&& zeros = make_empty_tensor( | |||
| grad->comp_node(), inputs[0].get(), grad->dtype()); | |||
| args_[0] = zeros.get(); | |||
| SmallVector<ValueRef> args_(inputs.size() + 1); | |||
| auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
| args_[0] = zeros; | |||
| args_[1] = grad; | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| args_[i + 1] = inputs[i].get(); | |||
| args_[i + 1] = inputs[i]; | |||
| } | |||
| ret[0] = python::apply(grad_op_, args_)[0]; | |||
| ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0]; | |||
| } | |||
| return ret; | |||
| }); | |||
| return apply(ctx); | |||
| maker.finalize(); | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<apply_result_t> reduce_grad_rule( | |||
| ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
| auto& op = ctx.op->cast_final_safe<Reduce>(); | |||
| if (op.mode == Reduce::Mode::SUM) { | |||
| if (ctx.nargs != 1) { | |||
| return {}; | |||
| } | |||
| std::array<std::shared_ptr<Tensor>, 1> input_shapes; | |||
| if (input_requires_grad(ctx, 0)) { | |||
| input_shapes[0] = get_shape(ctx.args[0]); | |||
| } | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([shapes = std::move(input_shapes)]( | |||
| BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
| mgb_assert(ngrads == 1); | |||
| Tensor* grad = grads[0]; | |||
| apply_result_t ret(1); | |||
| if (grad && shapes[0]) { | |||
| ret[0] = broadcast_to(grad, shapes[0].get()); | |||
| } | |||
| return ret; | |||
| }); | |||
| return apply(ctx); | |||
| std::optional<std::vector<ValueRef>> reduce_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto& reduce = op.cast_final_safe<Reduce>(); | |||
| if (reduce.mode != Reduce::Mode::SUM) { | |||
| return {}; | |||
| } | |||
| if (inputs.size() != 1) { | |||
| return {}; | |||
| } | |||
| std::array<ValueRef, 1> input_shapes; | |||
| if (inputs_require_grad[0]) { | |||
| input_shapes[0] = get_shape(inputs[0]); | |||
| } | |||
| return {}; | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| if (grad && shapes[0]) { | |||
| ret[0] = broadcast_to(grad, shapes[0]); | |||
| } | |||
| return ret; | |||
| }); | |||
| maker.finalize(); | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
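reduce_grad_rule is the mirror image: for y = sum(x), every element of x contributes to y with weight 1, so the backward pass simply broadcasts dL/dy back to x's saved shape — hence broadcast_to(grad, shapes[0]) above. In scalar terms:

#include <cstdio>

int main() {
    // y = sum over axis 0 of x with shape [2,3]; dL/dx broadcasts dL/dy.
    double gy[3] = {5, 7, 9};  // dL/dy
    double gx[2][3];
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 3; ++j)
            gx[i][j] = gy[j];  // broadcast_to(grad, shape(x))
    std::printf("%g %g\n", gx[0][0], gx[1][2]);  // 5 9
}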
| std::optional<apply_result_t> addAxis_grad_rule( | |||
| ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
| auto&& op = ctx.op->cast_final_safe<AddAxis>(); | |||
| mgb_assert(ctx.nargs == 1); | |||
| bool flag = input_requires_grad(ctx, 0); | |||
| auto&& grad_op = RemoveAxis::make(op.axis); | |||
| std::optional<std::vector<ValueRef>> addAxis_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& addAxis = op.cast_final_safe<AddAxis>(); | |||
| mgb_assert(inputs.size() == 1); | |||
| bool flag = inputs_require_grad[0]; | |||
| auto&& grad_op = RemoveAxis::make(addAxis.axis); | |||
| std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>()); | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([grad_op_ = std::move(grad_op), flag_ = flag]( | |||
| BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
| mgb_assert(ngrads == 1); | |||
| Tensor* grad = grads[0]; | |||
| apply_result_t ret(1); | |||
| maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| if (grad && flag_) { | |||
| ret[0] = python::apply(grad_op_, grad)[0]; | |||
| ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
| } | |||
| return ret; | |||
| }); | |||
| return apply(ctx); | |||
| maker.finalize(); | |||
| return imperative::apply(op, inputs); | |||
| } | |||
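Note the sort directions: AddAxis's gradient removes axes in descending index order (std::greater), while RemoveAxis's gradient re-adds them in ascending order. Erasing the highest index first keeps the lower indices valid; a small demonstration on a plain vector of dimension sizes:

#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
    std::vector<int> shape = {2, 1, 3, 1};  // drop the size-1 axes at 1 and 3
    std::vector<int> axes = {1, 3};
    std::sort(axes.begin(), axes.end(), std::greater<int>());
    for (int ax : axes)
        shape.erase(shape.begin() + ax);  // 3 first, then 1 — both still valid
    std::printf("%d %d\n", shape[0], shape[1]);  // 2 3
}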
| std::optional<apply_result_t> removeAxis_grad_rule( | |||
| ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
| auto&& op = ctx.op->cast_final_safe<RemoveAxis>(); | |||
| mgb_assert(ctx.nargs == 1); | |||
| bool flag = input_requires_grad(ctx, 0); | |||
| auto&& grad_op = AddAxis::make(op.axis); | |||
| std::optional<std::vector<ValueRef>> removeAxis_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& removeAxis = op.cast_final_safe<RemoveAxis>(); | |||
| mgb_assert(inputs.size() == 1); | |||
| bool flag = inputs_require_grad[0]; | |||
| auto&& grad_op = AddAxis::make(removeAxis.axis); | |||
| std::sort(grad_op->axis.begin(), grad_op->axis.end()); | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([grad_op_ = std::move(grad_op), flag_ = flag]( | |||
| BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
| mgb_assert(ngrads == 1); | |||
| Tensor* grad = grads[0]; | |||
| apply_result_t ret(1); | |||
| maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| if (grad && flag_) { | |||
| ret[0] = python::apply(grad_op_, grad)[0]; | |||
| ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
| } | |||
| return ret; | |||
| }); | |||
| return apply(ctx); | |||
| maker.finalize(); | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| std::optional<apply_result_t> fastpathcopy_grad_rule( | |||
| ApplyContext& ctx, CustomBackward::Maker& maker) { | |||
| mgb_assert(ctx.nargs == 1); | |||
| std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| mgb_assert(inputs.size() == 1); | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) { | |||
| mgb_assert(ngrads == 1); | |||
| Tensor* grad = grads[0]; | |||
| apply_result_t ret(1); | |||
| maker.backward([](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| if (grad) { | |||
| ret[0] = grad->shared_from_this(); | |||
| ret[0] = grad; | |||
| } | |||
| return ret; | |||
| }); | |||
| return apply(ctx); | |||
| maker.finalize(); | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| struct Init { | |||
| Init() { | |||
| auto& reg = grad_rule_registry(); | |||
| reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule); | |||
| reg.emplace(Reshape::typeinfo(), reshape_grad_rule); | |||
| reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule); | |||
| reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); | |||
| reg.emplace(Reduce::typeinfo(), reduce_grad_rule); | |||
| reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); | |||
| reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); | |||
| reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||
| CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule); | |||
| CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule); | |||
| CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule); | |||
| CustomBackward::register_grad_rule( | |||
| IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); | |||
| CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule); | |||
| CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule); | |||
| CustomBackward::register_grad_rule( | |||
| RemoveAxis::typeinfo(), removeAxis_grad_rule); | |||
| CustomBackward::register_grad_rule( | |||
| FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||
| } | |||
| } _; | |||
| @@ -1,245 +0,0 @@ | |||
| /** | |||
| * \file imperative/python/src/intrusive_list.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/utils/metahelper.h" | |||
| namespace mgb::imperative::python::intrusive_list { | |||
| // copy policy | |||
| struct after_t {}; | |||
| struct before_t {}; | |||
| struct disable_t {}; | |||
| template <typename T> | |||
| struct Tail; | |||
| // invariant: next->prev == this | |||
| template <typename T> | |||
| struct Head { | |||
| Tail<T>* next; | |||
| Head(Tail<T>* node = nullptr) : next(node) {} | |||
| Head(const Head<T>&) = delete; | |||
| Head<T>& operator=(const Head<T>&) = delete; | |||
| Head(Head<T>&& rhs) : next(rhs.next) { | |||
| rhs.next = nullptr; | |||
| if (next) { | |||
| next->prev = this; | |||
| } | |||
| } | |||
| Head<T>& operator=(Head<T>&& rhs) { | |||
| mgb_assert(!next); | |||
| next = rhs.next; | |||
| rhs.next = nullptr; | |||
| if (next) { | |||
| next->prev = this; | |||
| } | |||
| return *this; | |||
| } | |||
| ~Head() { | |||
| if (next) { | |||
| next->prev = nullptr; | |||
| } | |||
| } | |||
| }; | |||
| // invariant: prev->next == this | |||
| template <typename T> | |||
| struct Tail { | |||
| Head<T>* prev; | |||
| Tail(Head<T>* node = nullptr) : prev(node) {} | |||
| Tail(const Tail<T>&) = delete; | |||
| Tail<T>& operator=(const Tail<T>&) = delete; | |||
| Tail(Tail<T>&& rhs) : prev(rhs.prev) { | |||
| rhs.prev = nullptr; | |||
| if (prev) { | |||
| prev->next = this; | |||
| } | |||
| } | |||
| Tail<T>& operator=(Tail<T>&& rhs) { | |||
| mgb_assert(!prev); | |||
| prev = rhs.prev; | |||
| rhs.prev = nullptr; | |||
| if (prev) { | |||
| prev->next = this; | |||
| } | |||
| return *this; | |||
| } | |||
| ~Tail() { | |||
| if (prev) { | |||
| prev->next = nullptr; | |||
| } | |||
| } | |||
| }; | |||
| template <typename T, typename policy> | |||
| struct Node; | |||
| template <typename T> | |||
| class Iterator { | |||
| T* ptr; | |||
| void inc() { ptr = static_cast<T*>(ptr->Head<T>::next); } | |||
| void dec() { ptr = static_cast<T*>(ptr->Head<T>::prev); } | |||
| public: | |||
| Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {} | |||
| Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {} | |||
| template <typename policy> | |||
| Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {} | |||
| T& operator*() { return *static_cast<T*>(ptr); } | |||
| T* operator->() { return static_cast<T*>(ptr); } | |||
| operator bool() { return ptr; } | |||
| bool operator==(const Iterator<T>& rhs) { return ptr == rhs.ptr; } | |||
| Iterator& operator++() { | |||
| inc(); | |||
| return *this; | |||
| } | |||
| Iterator& operator--() { | |||
| dec(); | |||
| return *this; | |||
| } | |||
| Iterator operator++(int) { | |||
| auto ret = *this; | |||
| inc(); | |||
| return ret; | |||
| } | |||
| Iterator operator--(int) { | |||
| auto ret = *this; | |||
| dec(); | |||
| return ret; | |||
| } | |||
| }; | |||
| // Node in a doubly linked list. Unlike std::list, nodes are not owned by a container. | |||
| // Instead, nodes may join or leave a list freely. | |||
| // NOTE: Derived classes have to explicitly declare copy / assignment as default, | |||
| // otherwise the compiler generated version would use the const T& signature, | |||
| // which is deleted. | |||
| template <typename T = void, typename policy = disable_t> | |||
| struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>, | |||
| Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> { | |||
| private: | |||
| using this_t = Node<T, policy>; | |||
| using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>; | |||
| public: | |||
| using head_t = Head<U>; | |||
| using tail_t = Tail<U>; | |||
| using head_t::next; | |||
| using tail_t::prev; | |||
| Node() = default; | |||
| Node(const this_t&) = delete; | |||
| this_t& operator=(const this_t&) = delete; | |||
| //! constructed node is inserted after the input node | |||
| Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) { | |||
| node.next = this; | |||
| if (next) { | |||
| next->prev = this; | |||
| } | |||
| } | |||
| //! constructed node is inserted before the input node | |||
| Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) { | |||
| node.prev = this; | |||
| if (prev) { | |||
| prev->next = this; | |||
| } | |||
| } | |||
| Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) { | |||
| rhs.prev = nullptr; | |||
| rhs.next = nullptr; | |||
| if (prev) { | |||
| prev->next = this; | |||
| } | |||
| if (next) { | |||
| next->prev = this; | |||
| } | |||
| } | |||
| Node& operator=(this_t&& rhs) { | |||
| unlink(); | |||
| prev = rhs.prev; | |||
| next = rhs.next; | |||
| rhs.prev = nullptr; | |||
| rhs.next = nullptr; | |||
| if (prev) { | |||
| prev->next = this; | |||
| } | |||
| if (next) { | |||
| next->prev = this; | |||
| } | |||
| return *this; | |||
| } | |||
| template < | |||
| typename p = policy, | |||
| typename = std::enable_if_t< | |||
| std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||
| Node(this_t& rhs) : Node(policy{}, rhs) {} | |||
| template < | |||
| typename p = policy, | |||
| typename = std::enable_if_t< | |||
| std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||
| this_t& operator=(this_t& rhs) { | |||
| insert(policy{}, rhs); | |||
| return *this; | |||
| } | |||
| void unlink() { | |||
| if (prev) { | |||
| prev->next = next; | |||
| } | |||
| if (next) { | |||
| next->prev = prev; | |||
| } | |||
| prev = nullptr; | |||
| next = nullptr; | |||
| } | |||
| //! this node is unlinked from its list and inserted after the input node | |||
| void insert(after_t, head_t& node) { | |||
| unlink(); | |||
| prev = &node; | |||
| next = node.next; | |||
| node.next = this; | |||
| if (next) { | |||
| next->prev = this; | |||
| } | |||
| } | |||
| //! this node is unlinked from its list and inserted before the input node | |||
| void insert(before_t, tail_t& node) { | |||
| unlink(); | |||
| next = &node; | |||
| prev = node.prev; | |||
| node.prev = this; | |||
| if (prev) { | |||
| prev->next = this; | |||
| } | |||
| } | |||
| void insert_before(tail_t& node) { insert(before_t{}, node); } | |||
| void insert_after(head_t& node) { insert(after_t{}, node); } | |||
| ~Node() { unlink(); } | |||
| }; | |||
| } // namespace mgb::imperative::python::intrusive_list | |||
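As the comment above says, these nodes are not owned by any container: a node carries its own prev/next links, joins and leaves lists freely, and ~Node() unlinks itself so no dangling pointers survive it. A stripped-down single-type version of the same idea (the real template above also threads through a CRTP element type and a copy policy):

#include <cstdio>

struct Node {
    Node* prev = nullptr;
    Node* next = nullptr;
    void insert_after(Node& n) {
        prev = &n;
        next = n.next;
        if (next) next->prev = this;
        n.next = this;
    }
    ~Node() {  // unlink, restoring the neighbours' invariants
        if (prev) prev->next = next;
        if (next) next->prev = prev;
    }
};

int main() {
    Node head, a, b;
    a.insert_after(head);
    b.insert_after(a);  // head -> a -> b
    {
        Node tmp;
        tmp.insert_after(head);  // head -> tmp -> a -> b
    }                            // tmp unlinks itself here
    std::printf("%d\n", head.next == &a && a.next == &b);  // 1
}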
| @@ -1,42 +0,0 @@ | |||
| /** | |||
| * \file imperative/python/src/module_trace.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./module_trace.h" | |||
| #include "./helper.h" // include op pybind11 caster | |||
| namespace py = pybind11; | |||
| namespace mgb::imperative::python { | |||
| apply_result_t apply_module_trace(ApplyContext& ctx) { | |||
| apply_result_t outputs; | |||
| auto args = py::tuple(ctx.nargs + 1); | |||
| args[0] = py::cast(ctx.op); | |||
| for (size_t i = 0; i < ctx.nargs; i++) { | |||
| args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); | |||
| } | |||
| auto pyout = PyObject_Call(cpp_apply_module_trace, args.ptr(), nullptr); | |||
| if (!pyout) | |||
| throw py::error_already_set(); | |||
| auto ret = py::reinterpret_steal<py::object>(pyout); | |||
| // assumption: python function always returns PyList | |||
| auto tup = py::reinterpret_borrow<py::list>(ret); | |||
| for (auto i = 0; i < tup.size(); i++) { | |||
| auto tw = TensorWrapper::try_cast(tup[i].ptr()); | |||
| outputs.emplace_back(tw->m_tensor); | |||
| } | |||
| return outputs; | |||
| } | |||
| } // namespace mgb::imperative::python | |||
| @@ -11,10 +11,50 @@ | |||
| #pragma once | |||
| #include "megbrain/imperative/transformations/trace.h" | |||
| #include "megbrain/imperative/utils/map.h" | |||
| #include "./tensor.h" | |||
| namespace mgb::imperative::python { | |||
| apply_result_t apply_module_trace(ApplyContext& ctx); | |||
| namespace py = pybind11; | |||
| class ModuleTraceTransformation final : public Transformation { | |||
| private: | |||
| py::function m_hook_fn; | |||
| int m_enabled = 0; | |||
| std::vector<ValueRef> apply_module_trace_hook( | |||
| const OpDef& op, Span<ValueRef> input_values) { | |||
| py::list input_tws; | |||
| for (auto&& input_value : input_values) { | |||
| input_tws.append(TensorWrapper::make(py_tensor_type, input_value)); | |||
| } | |||
| py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws); | |||
| std::vector<ValueRef> outputs; | |||
| for (auto&& output_tw : output_tws) { | |||
| outputs.push_back( | |||
| TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data()); | |||
| } | |||
| return outputs; | |||
| } | |||
| public: | |||
| ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} | |||
| std::vector<ValueRef> apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override { | |||
| if (op.is<ApplyOp>() && m_enabled > 0) { | |||
| auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs); | |||
| return outputs; | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } | |||
| ValueRef unwrap(ValueRef value) override { return value; } | |||
| std::string name() const override { return "ModuleTraceTransformation"; } | |||
| }; | |||
| } // namespace mgb::imperative::python | |||
| @@ -185,7 +185,8 @@ int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) { | |||
| } | |||
| PyGetSetDef PyOp(OpDef)::py_getsetters[] = { | |||
| {const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL}, | |||
| {const_cast<char*>("scope"), py_get_scope, py_set_scope, | |||
| const_cast<char*>("scope"), NULL}, | |||
| {NULL}}; | |||
| Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) { | |||
| @@ -556,12 +557,6 @@ void init_ops(py::module m) { | |||
| m.def( | |||
| "delete_rng_handle", | |||
| [](size_t handle) { | |||
| // RNG op might execute after handle released due to async dispatch, so | |||
| // we need sync before delete a handle to avoid memory leak or | |||
| // use-after-free | |||
| if (python::interpreter_for_py->check_available()) { | |||
| python::interpreter_for_py->sync(); | |||
| } | |||
| mgb::CompNode::sync_all(); | |||
| py_task_q.wait_all_task_finish(); | |||
| rng::delete_handle(handle); | |||
| @@ -20,6 +20,8 @@ | |||
| #include "pybind11/pybind11.h" | |||
| #include "./pyext17.h" | |||
| #include "megbrain/imperative/dispatch.h" | |||
| #include "megbrain/imperative/utils/span.h" | |||
| namespace mgb::imperative::python { | |||
| @@ -32,126 +34,67 @@ struct ObjectPtr : B { | |||
| } // namespace mgb::imperative::python | |||
| #include "./grad_info.h" // for struct GradInfo | |||
| #include "./trace_info.h" // for struct TraceInfo | |||
| namespace mgb::imperative::python { | |||
| struct GradKey; | |||
| extern interpreter::Interpreter::Channel* interpreter_for_py; | |||
| extern PyTypeObject* py_tensor_type; | |||
| class SharedHandle { | |||
| using Handle = interpreter::Interpreter::Handle; | |||
| static_assert(std::is_pointer_v<Handle>); | |||
| std::shared_ptr<std::remove_pointer_t<Handle>> holder; | |||
| public: | |||
| inline explicit SharedHandle(Handle handle) | |||
| : holder(handle, [](auto* h) { | |||
| if (h) { | |||
| interpreter_for_py->del(h); | |||
| } | |||
| }) {} | |||
| SharedHandle(const SharedHandle&) = default; | |||
| SharedHandle& operator=(const SharedHandle&) = default; | |||
| SharedHandle(SharedHandle&&) = default; | |||
| SharedHandle& operator=(SharedHandle&&) = default; | |||
| inline Handle get() { return holder.get(); } | |||
| }; | |||
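SharedHandle above is the shared_ptr-with-custom-deleter pattern: an opaque interpreter handle is wrapped so that copies share ownership and the deleter (interpreter_for_py->del) runs exactly once, when the last copy dies. The same trick applied to a FILE* handle:

#include <cstdio>
#include <memory>

std::shared_ptr<FILE> open_shared(const char* path) {
    return std::shared_ptr<FILE>(std::fopen(path, "r"), [](FILE* f) {
        if (f) std::fclose(f);  // runs once, when the last copy is destroyed
    });
}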
| // impl in grad.cpp | |||
| class GradInfoCollection { | |||
| struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
| private: | |||
| SmallVector<GradInfo> m_storage; | |||
| protected: | |||
| void _shrink(); | |||
| std::string m_name; | |||
| ValueRef m_data; | |||
| public: | |||
| bool contains(GradKey* key); | |||
| GradInfo& operator[](GradKey* key); | |||
| GradInfo& at(GradKey* key); | |||
| bool empty() { | |||
| _shrink(); | |||
| return m_storage.empty(); | |||
| } | |||
| auto begin() { | |||
| _shrink(); | |||
| return m_storage.begin(); | |||
| } | |||
| auto end() { | |||
| _shrink(); | |||
| return m_storage.end(); | |||
| } | |||
| size_t count(GradKey* key) { return contains(key) ? 1 : 0; } | |||
| }; | |||
| struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
| using flags_t = uint64_t; | |||
| struct Flags { | |||
| static constexpr flags_t SCALAR = 1; | |||
| static constexpr flags_t GRAD = 1 << 1; | |||
| static constexpr flags_t TRACE = 1 << 2; | |||
| static constexpr flags_t MODULE_TRACE = 1 << 3; | |||
| }; | |||
| flags_t m_flags = 0; | |||
| GradInfoCollection m_grad_info_dict; | |||
| TraceInfo m_trace_info; | |||
| SharedHandle m_handle; | |||
| std::string user_custom_name; | |||
| std::string automatic_name; | |||
| cg::VarNode* m_var; | |||
| pybind11::object m_module_trace_info; | |||
| using Handle = interpreter::Interpreter::Handle; | |||
| inline Tensor() : m_handle(nullptr), m_var(nullptr) {} | |||
| inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {} | |||
| inline explicit Tensor(SharedHandle handle) | |||
| : m_handle(std::move(handle)), m_var(nullptr) {} | |||
| inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {} | |||
| inline explicit Tensor(ValueRef data) : m_data{data} {} | |||
| ~Tensor() = default; | |||
| inline std::shared_ptr<Tensor> copy() { | |||
| auto ret = std::make_shared<Tensor>(m_handle); | |||
| ret->m_flags = m_flags; | |||
| ret->m_grad_info_dict = m_grad_info_dict; | |||
| ret->m_trace_info = m_trace_info; | |||
| ret->m_var = m_var; | |||
| auto ret = std::make_shared<Tensor>(m_data.unwrap()); | |||
| ret->m_name = m_name; | |||
| return ret; | |||
| } | |||
| inline DType dtype() { | |||
| if (m_var) { | |||
| return m_var->dtype(); | |||
| inline DType dtype() { return *data().dtype(); } | |||
| inline CompNode comp_node() { return *data().device(); } | |||
| inline std::optional<ValueShape> shape() { | |||
| auto shape = data().shape(); | |||
| if (!shape) { | |||
| return {}; | |||
| } | |||
| return interpreter_for_py->get_dtype(m_handle.get()); | |||
| return *shape; | |||
| } | |||
| inline CompNode comp_node() { | |||
| if (m_var) { | |||
| return m_var->comp_node(); | |||
| inline HostValue::ref_t numpy() { return data().numpy(); } | |||
| inline void reset(ValueRef value) { | |||
| m_data = value; | |||
| if (!m_name.empty()) { | |||
| set_name(m_name); | |||
| } | |||
| return interpreter_for_py->get_device(m_handle.get()); | |||
| } | |||
| inline TensorShape shape() { | |||
| if (m_var) { | |||
| return m_var->shape(); | |||
| inline ValueRef data() { return m_data.unwrap(); } | |||
| bool is_scalar() { return data().is_scalar(); } | |||
| inline std::string name() { return m_name; } | |||
| inline void set_name(std::string name) { | |||
| m_name = name; | |||
| if (!name.empty()) { | |||
| auto output = imperative::apply(RenameValue(name), m_data)[0]; | |||
| m_data = output; | |||
| } | |||
| return interpreter_for_py->get_shape(m_handle.get()); | |||
| } | |||
| }; | |||
| struct TensorWrapper { | |||
| public: | |||
| std::shared_ptr<Tensor> m_tensor; | |||
| inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) | |||
| : m_tensor(std::move(tensor)) {} | |||
| : m_tensor(std::move(tensor)) { | |||
| mgb_assert(m_tensor, "empty storage");  // 'tensor' itself was moved-from above | |||
| } | |||
| inline TensorWrapper(ValueRef value) : m_tensor(std::make_shared<Tensor>(value)) {} | |||
| TensorWrapper(PyObject* args, PyObject* kwargs); | |||
| ~TensorWrapper() = default; | |||
| @@ -191,33 +134,17 @@ struct TensorWrapper { | |||
| void reset(PyObject*); | |||
| PyObject* detach(); | |||
| PyObject* isscalar(); | |||
| void setscalar(); | |||
| void unsetscalar(); | |||
| PyObject* _dev_tensor(); | |||
| void _drop(); | |||
| PyObject* varnode(); | |||
| void reset_varnode(); | |||
| PyObject* handle(); | |||
| void set_handle(PyObject*); | |||
| PyObject* mixin_handle(); | |||
| PyObject* recording(); | |||
| PyObject* copied(); | |||
| void set_mixin_handle(PyObject*); | |||
| void set_recording(PyObject*); | |||
| PyObject* compiled_info(); | |||
| void set_compiled_info(PyObject*); | |||
| PyObject* trace_mixin_info(); | |||
| void set_trace_mixin_info(PyObject*); | |||
| PyObject* module_trace_info(); | |||
| void set_module_trace_info(PyObject*); | |||
| PyObject* user_custom_name(); | |||
| void set_user_custom_name(PyObject*); | |||
| PyObject* automatic_name(); | |||
| void set_automatic_name(PyObject*); | |||
| void _set_name(PyObject*); | |||
| PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | |||
| PyObject* _detail(); | |||
| void _watch(); | |||
| }; | |||
| struct PySymbolVar { | |||
| @@ -230,113 +157,8 @@ struct PySymbolVar { | |||
| PyObject* py_apply( | |||
| PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */); | |||
| struct ApplyContext { | |||
| static Tensor::flags_t global_disable; | |||
| static Tensor::flags_t global_enable; | |||
| Tensor::flags_t flags = 0; | |||
| std::shared_ptr<OpDef> op; | |||
| Tensor* const* args; | |||
| size_t nargs; | |||
| PyTypeObject* pytype = nullptr; | |||
| bool backward = false; | |||
| class scoped_disable : NonCopyableObj { | |||
| Tensor::flags_t saved_flags; | |||
| public: | |||
| scoped_disable(Tensor::flags_t flags) | |||
| : saved_flags(ApplyContext::global_disable) { | |||
| ApplyContext::global_disable |= flags; | |||
| } | |||
| ~scoped_disable() { ApplyContext::global_disable = saved_flags; } | |||
| }; | |||
| }; | |||
| using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; | |||
| apply_result_t apply(ApplyContext& ctx); | |||
| template <typename T> | |||
| decltype(auto) resolve_arrow(T&& p) { | |||
| if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) { | |||
| auto* ret = p; | |||
| return ret; | |||
| } else { | |||
| auto probe = [](auto&& p) -> decltype(p.operator->()) {}; | |||
| if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) { | |||
| return resolve_arrow(p.operator->()); | |||
| } else { | |||
| return std::forward<T>(p); | |||
| } | |||
| } | |||
| } | |||
| template <typename... Args> | |||
| constexpr bool is_all_tensor_ptr = | |||
| (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>); | |||
| template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0> | |||
| apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) { | |||
| ApplyContext ctx; | |||
| Tensor* arg_arr[] = {resolve_arrow(args)...}; | |||
| ctx.flags = (0 | ... | args->m_flags); | |||
| ctx.args = arg_arr; | |||
| ctx.nargs = sizeof...(args); | |||
| ctx.op = std::move(op); | |||
| return apply(ctx); | |||
| } | |||
| inline auto apply(std::shared_ptr<OpDef> op, Tensor* const* args, size_t nargs) { | |||
| ApplyContext ctx; | |||
| ctx.op = std::move(op); | |||
| ctx.nargs = nargs; | |||
| ctx.args = args; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| ctx.flags |= args[i]->m_flags; | |||
| } | |||
| return apply(ctx); | |||
| } | |||
| template <typename T> | |||
| auto apply(std::shared_ptr<OpDef> op, T&& tensors) -> std::enable_if_t< | |||
| std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, apply_result_t> { | |||
| size_t nargs = tensors.size(); | |||
| Tensor* args[nargs]; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| args[i] = resolve_arrow(tensors[i]); | |||
| } | |||
| return apply(op, args, nargs); | |||
| } | |||
| std::shared_ptr<Tensor> make_const(imperative::TensorPtr value); | |||
| inline auto apply(Subgraph graph, Tensor* const* args, size_t nargs) { | |||
| SmallVector<std::shared_ptr<Tensor>> inputs; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| inputs.push_back(args[i]->shared_from_this()); | |||
| } | |||
| auto apply_functor = [](std::shared_ptr<OpDef> op, | |||
| SmallVector<std::shared_ptr<Tensor>> inputs, | |||
| size_t) { return apply(op, std::move(inputs)); }; | |||
| return graph.apply(inputs, apply_functor, &make_const); | |||
| } | |||
| template <typename T> | |||
| auto apply(Subgraph graph, T&& tensors) -> std::enable_if_t< | |||
| std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, apply_result_t> { | |||
| size_t nargs = tensors.size(); | |||
| Tensor* args[nargs]; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| args[i] = resolve_arrow(tensors[i]); | |||
| } | |||
| return apply(graph, args, nargs); | |||
| } | |||
| void init_tensor(pybind11::module); | |||
| extern PyObject* cpp_apply_with_tracing; | |||
| extern PyObject* cpp_apply_backward_varnode; | |||
| extern PyObject* cpp_apply_module_trace; | |||
| } // namespace mgb::imperative::python | |||
| @@ -1,63 +0,0 @@ | |||
| /** | |||
| * \file imperative/python/src/trace.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./trace.h" | |||
| #include "./helper.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| namespace py = pybind11; | |||
| namespace mgb::imperative::python { | |||
| apply_result_t apply_trace(ApplyContext& ctx) { | |||
| apply_result_t outputs; | |||
| if (ctx.backward) { | |||
| // reach here when compiled=True | |||
| auto args = py::tuple(ctx.nargs + 1); | |||
| args[0] = py::cast(ctx.op); | |||
| for (size_t i = 0; i < ctx.nargs; i++) { | |||
| args[i + 1] = py::cast(ctx.args[i]->m_var); | |||
| } | |||
| py::object pyout = py::reinterpret_steal<py::object>( | |||
| PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr)); | |||
| if (!pyout) | |||
| throw py::error_already_set(); | |||
| // assumption: python function always returns PyList | |||
| auto tup = py::reinterpret_borrow<py::list>(pyout); | |||
| for (size_t i = 0; i < tup.size(); i++) { | |||
| auto pitem = tup[i].cast<cg::VarNode*>(); | |||
| outputs.emplace_back(std::make_shared<Tensor>(pitem)); | |||
| } | |||
| return outputs; | |||
| } | |||
| auto args = py::tuple(ctx.nargs + 1); | |||
| args[0] = py::cast(ctx.op); | |||
| for (size_t i = 0; i < ctx.nargs; i++) { | |||
| args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); | |||
| } | |||
| auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr); | |||
| if (!pyout) | |||
| throw py::error_already_set(); | |||
| // assumption: python function always returns PyList | |||
| auto tup = py::reinterpret_steal<py::list>(pyout); | |||
| for (size_t i = 0; i < tup.size(); i++) { | |||
| auto tw = TensorWrapper::try_cast(tup[i].ptr()); | |||
| outputs.emplace_back(tw->m_tensor); | |||
| } | |||
| return outputs; | |||
| } | |||
| } // namespace mgb::imperative::python | |||
| @@ -1,28 +0,0 @@ | |||
| /** | |||
| * \file imperative/python/src/trace.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include <stdexcept> | |||
| #include "./tensor.h" | |||
| namespace mgb::imperative::python { | |||
| class TraceReadError : public std::exception { | |||
| public: | |||
| explicit TraceReadError(const char* m) : message{m} {} | |||
| const char* what() const noexcept override { return message.c_str(); } | |||
| private: | |||
| std::string message = ""; | |||
| }; | |||
| apply_result_t apply_trace(ApplyContext& ctx); | |||
| } // namespace mgb::imperative::python | |||
| @@ -1,49 +0,0 @@ | |||
| /** | |||
| * \file imperative/python/src/trace_info.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "Python.h" | |||
| #include "inttypes.h" | |||
| namespace mgb::imperative::python { | |||
| struct TraceInfo { | |||
| int64_t mixin_handle = -1; | |||
| bool recording = false; | |||
| // refer to CompiledTensorProxy in tracing.py, works from second trace step | |||
| PyObject* compiled_info = nullptr; | |||
| // refer to TensorInfo in tracing.py, only works in first trace step | |||
| PyObject* trace_mixin_info = nullptr; | |||
| TraceInfo() = default; | |||
| TraceInfo& operator=(const TraceInfo& that) { | |||
| mixin_handle = that.mixin_handle; | |||
| recording = that.recording; | |||
| trace_mixin_info = that.trace_mixin_info; | |||
| Py_XINCREF(trace_mixin_info); | |||
| compiled_info = that.compiled_info; | |||
| Py_XINCREF(compiled_info); | |||
| return *this; | |||
| } | |||
| ~TraceInfo() { | |||
| Py_XDECREF(trace_mixin_info); | |||
| Py_XDECREF(compiled_info); | |||
| } | |||
| private: | |||
| TraceInfo(const TraceInfo& that) = default; | |||
| }; | |||
| } // namespace mgb::imperative::python | |||
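TraceInfo's copy-assignment above increfs the incoming compiled_info / trace_mixin_info but never decrefs the values being overwritten, which is only safe when the target's pointers are still null — presumably how Tensor::copy uses it on a default-constructed object. The conventional self-assignment-safe form increfs before decrefing; a sketch against the CPython C API:

#include <Python.h>

// Self-assignment-safe refcounted pointer assignment.
void assign(PyObject*& dst, PyObject* src) {
    Py_XINCREF(src);   // incref first, so dst == src is harmless
    Py_XDECREF(dst);
    dst = src;
}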
| @@ -16,10 +16,6 @@ import megengine.module | |||
| from megengine import Parameter | |||
| from megengine.core._imperative_rt.core2 import sync | |||
| from megengine.device import get_device_count | |||
| from megengine.experimental.autograd import ( | |||
| disable_higher_order_directive, | |||
| enable_higher_order_directive, | |||
| ) | |||
| from megengine.jit import trace as _trace | |||
| from megengine.module import Linear, Module | |||
| @@ -45,13 +41,3 @@ def skip_distributed(request): | |||
| platform.system() | |||
| ) | |||
| ) | |||
| @pytest.fixture(autouse=True) | |||
| def resolve_require_higher_order_directive(request): | |||
| marker = request.node.get_closest_marker("require_higher_order_directive") | |||
| if marker: | |||
| enable_higher_order_directive() | |||
| yield | |||
| if marker: | |||
| disable_higher_order_directive() | |||
| @@ -146,5 +146,5 @@ def test_dump_bn_train_mode(): | |||
| data = mge.tensor(np.random.random((10, 10, 10, 10))) | |||
| bn_train(data) | |||
| with pytest.raises(AssertionError): | |||
| with pytest.raises(RuntimeError): | |||
| bn_train.dump("test.mge") | |||
| @@ -17,7 +17,7 @@ import megengine.distributed as dist | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.optimizer as optim | |||
| from megengine.autodiff import GradManager | |||
| from megengine.autodiff import Function, GradManager | |||
| from megengine.jit import trace | |||
| @@ -214,7 +214,7 @@ def test_remote_grad(trace_mode): | |||
| x = dist.functional.remote_recv(rank - 1) | |||
| y = m(x) | |||
| if rank != size - 1: | |||
| dist.functional.remote_send(y, dest_rank=rank + 1) | |||
| x = dist.functional.remote_send(y, dest_rank=rank + 1) | |||
| gm.backward() | |||
| else: | |||
| y = y.mean() | |||
| @@ -224,7 +224,7 @@ def test_remote_grad(trace_mode): | |||
| if trace_mode is not None: | |||
| train_func = trace(symbolic=trace_mode)(train_func) | |||
| for i in range(3): | |||
| for i in range(1): | |||
| train_func(x) | |||
| worker() | |||
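The substantive change in `test_remote_grad` is that the placeholder tensor returned by `remote_send` is now bound to a name rather than discarded, which is what ties the send op into the recorded graph (the gradient tests further below spell out the same reason in a comment). A minimal sketch of the pattern, assuming the usual `dist.launcher`-spawned worker; `pipeline_step` and `m` are illustrative names, and this does not run standalone:

    import megengine.distributed as dist
    from megengine.autodiff import GradManager

    def pipeline_step(m, x, gm: GradManager):
        # runs inside a dist.launcher worker; rank 0 feeds the pipeline
        rank, size = dist.get_rank(), dist.get_world_size()
        with gm:
            if rank != 0:
                x = dist.functional.remote_recv(rank - 1)
            y = m(x)
            if rank != size - 1:
                # keep the returned placeholder alive: discarding it would
                # orphan the send op in the autodiff graph
                x = dist.functional.remote_send(y, dest_rank=rank + 1)
                gm.backward()
            else:
                gm.backward(y.mean())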
| @@ -340,7 +340,6 @@ def test_broadcast_grad(trace_mode): | |||
| worker() | |||
| @pytest.mark.require_higher_order_directive() | |||
| def test_2nd_grad_with_manager(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = mge.tensor(x_np) | |||
| @@ -359,7 +358,6 @@ def test_2nd_grad_with_manager(): | |||
| ) | |||
| @pytest.mark.require_higher_order_directive() | |||
| def test_grad_manager_group(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = mge.tensor(x_np) | |||
| @@ -376,7 +374,6 @@ def test_grad_manager_group(): | |||
| x.grad = None | |||
| @pytest.mark.require_higher_order_directive() | |||
| def test_grad_manager_group_visibility(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = mge.tensor(x_np) | |||
| @@ -392,7 +389,6 @@ def test_grad_manager_group_visibility(): | |||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||
| @pytest.mark.require_higher_order_directive() | |||
| def test_grad_manager_visibility_by_order(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = mge.tensor(x_np) | |||
| @@ -410,7 +406,6 @@ def test_grad_manager_visibility_by_order(): | |||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||
| @pytest.mark.require_higher_order_directive() | |||
| @pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1]) | |||
| def test_emulate_forward_mode_with_reverse_mode(target): | |||
| def jvp(inp, expr): | |||
| @@ -434,3 +429,43 @@ def test_emulate_forward_mode_with_reverse_mode(target): | |||
| np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5) | |||
| np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3) | |||
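For reference, the trick this test exercises can be written out compactly: `g(w) = wᵀJ` is linear in `w`, so a second reverse pass through it, seeded with `v`, recovers the forward-mode product `Jv`. A minimal sketch under the assumption these tests rely on, namely that an outer manager records the inner manager's backward ops; `jvp`, `f`, and `v` are illustrative names:

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    def jvp(f, x, v):
        gm = GradManager().attach([x])
        w = mge.Tensor(np.ones_like(x.numpy()))  # dummy cotangent
        gm2 = GradManager().attach([w])
        with gm2:
            with gm:
                y = f(x)
                gm.backward(y, w)    # x.grad = w^T J, recorded by gm2
            gm2.backward(x.grad, v)  # differentiate w.r.t. w: w.grad = J v
        return y, w.grad

    x = mge.Tensor(np.random.rand(10).astype("float32"))
    v = mge.Tensor(np.random.rand(10).astype("float32"))
    y, jv = jvp(F.sin, x, v)
    np.testing.assert_almost_equal(jv.numpy(), np.cos(x.numpy()) * v.numpy(), decimal=3)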
| def test_2nd_grad_with_custom_gradient(): | |||
| class MySin(Function): | |||
| def forward(self, x): | |||
| self.inp = x | |||
| x = mge.Tensor(x.numpy()) | |||
| y = F.sin(x) | |||
| return y | |||
| def backward(self, dy): | |||
| dx = F.cos(self.inp) * dy | |||
| return dx | |||
| class MyCos(Function): | |||
| def forward(self, x): | |||
| self.inp = x | |||
| x = mge.Tensor(x.numpy()) | |||
| y = F.cos(x) | |||
| return y | |||
| def backward(self, dy): | |||
| dx = -MySin()(self.inp) * dy | |||
| return dx | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = mge.tensor(x_np) | |||
| gm = GradManager().attach([x]) | |||
| gm2 = GradManager().attach([x]) | |||
| with gm: | |||
| with gm2: | |||
| y = MyCos()(x) | |||
| gm2.backward(y) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||
| gm.backward(x.grad) | |||
| np.testing.assert_almost_equal( | |||
| x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5 | |||
| ) | |||
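Note how the construction makes the second-order case work: `MyCos.backward` builds its result from `MySin`, itself a `Function`, so the outer manager can differentiate through the inner backward pass. `gm.backward(x.grad)` then differentiates the first-order gradient `z = -sin(x)` with respect to `x`, and because `.grad` accumulates, the final value checked is `-sin(x) - cos(x)`.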
| @@ -7,8 +7,6 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import gc | |||
| import platform | |||
| import weakref | |||
| import numpy as np | |||
| import pytest | |||
| @@ -60,24 +58,20 @@ def test_dist_grad(): | |||
| def worker(): | |||
| rank = dist.get_rank() | |||
| if rank == 0: | |||
| grad = Grad() | |||
| x = as_tensor(x_np) | |||
| grad.wrt(x, callback=save_to(x)) | |||
| # need a placeholder to trace the operator | |||
| remote_send(x, 1) | |||
| recv_x = remote_recv(1) | |||
| y = recv_x * recv_x | |||
| grad([y], [as_tensor(np.ones_like(x_np))]) | |||
| with Grad() as grad: | |||
| x = as_tensor(x_np) | |||
| grad.wrt(x, callback=save_to(x)) | |||
| # need a placeholder to trace the operator | |||
| remote_send(x, 1) | |||
| recv_x = remote_recv(1) | |||
| y = recv_x * recv_x | |||
| grad([y], [as_tensor(np.ones_like(x_np))]) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2) | |||
| elif rank == 1: | |||
| grad = Grad() | |||
| recv_x = remote_recv(0) | |||
| remote_send(recv_x, 0) | |||
| grad([], []) | |||
| with Grad() as grad: | |||
| recv_x = remote_recv(0) | |||
| remote_send(recv_x, 0) | |||
| grad([], []) | |||
| worker() | |||
| @@ -86,11 +80,11 @@ def test_grad(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = as_tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = cos(x) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| y = cos(x) | |||
| grad(y, as_tensor(np.ones_like(x_np))) | |||
| grad(y, as_tensor(np.ones_like(x_np))) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np)) | |||
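This and the following tests all migrate to the same shape: `Grad` is now a context manager, with `wrt` declarations and the backward call made inside the block, and release handled on exit. A minimal sketch of the new usage; the import path and the `save_to` helper (equivalent to the tests') are assumptions:

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.core.autodiff.grad import Grad  # import path is an assumption

    def save_to(t):
        def callback(grad):
            t.grad = grad
        return callback

    x = mge.Tensor(np.random.rand(10).astype("float32"))
    with Grad() as grad:                   # __enter__ starts recording
        grad.wrt(x, callback=save_to(x))   # declare what to differentiate
        y = F.cos(x)
        grad(y, F.ones_like(y))            # backward runs inside the context
    np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x.numpy()), decimal=5)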
| @@ -98,12 +92,12 @@ def test_grad_2(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = as_tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| y = mul(x, x) | |||
| y = mul(y, y) | |||
| grad(y, as_tensor(np.ones_like(x_np))) | |||
| y = mul(x, x) | |||
| y = mul(y, y) | |||
| grad(y, as_tensor(np.ones_like(x_np))) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | |||
| @@ -113,32 +107,31 @@ def test_2nd_grad(): | |||
| x = as_tensor(x_np) | |||
| ones = as_tensor(np.ones_like(x_np)) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| grad._priority = -1 | |||
| grad2 = Grad().wrt(x, callback=save_to(x)) | |||
| grad2._priority = 0 | |||
| y = cos(x) | |||
| with Grad("grad2") as grad2: | |||
| with Grad("grad") as grad: | |||
| grad2.wrt(x, callback=save_to(x)) | |||
| grad.wrt(x, callback=save_to(x)) | |||
| y = cos(x) | |||
| grad(y, ones) | |||
| z = x.grad | |||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||
| grad(y, ones) | |||
| z = x.grad | |||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||
| x.grad = None | |||
| grad2(z, ones) | |||
| x.grad = None | |||
| grad2(z, ones) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) | |||
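The rewrite replaces the manual `_priority` bookkeeping with nesting order: the inner `Grad("grad")` runs its backward while the outer `Grad("grad2")` is still recording, so `z = x.grad = -sin(x)` remains differentiable, and `grad2(z, ones)` then produces `-cos(x)`.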
| def test_grad_with_tensor_wrapper(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| y = mul(x, x) | |||
| y = mul(y, y) | |||
| grad(y, mge.Tensor(np.ones_like(x_np))) | |||
| y = mul(x, x) | |||
| y = mul(y, y) | |||
| grad(y, mge.Tensor(np.ones_like(x_np))) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | |||
| @@ -162,18 +155,21 @@ def test_release(): | |||
| @check | |||
| def _(): | |||
| g = Grad().wrt(x) | |||
| y = x * x | |||
| g(y, dy) | |||
| with Grad() as g: | |||
| g.wrt(x) | |||
| y = x * x | |||
| g(y, dy) | |||
| @check | |||
| def _(): | |||
| with Grad().wrt(x): | |||
| with Grad() as g: | |||
| g.wrt(x) | |||
| pass | |||
| @check | |||
| def _(): | |||
| with Grad().wrt(x): | |||
| with Grad() as g: | |||
| g.wrt(x) | |||
| y = x * x | |||
| @@ -181,12 +177,12 @@ def test_grad_inplace(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| y = mul(x, x) | |||
| y *= y | |||
| grad(y, mge.Tensor(np.ones_like(x_np))) | |||
| y = mul(x, x) | |||
| y *= y | |||
| grad(y, mge.Tensor(np.ones_like(x_np))) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | |||
| @@ -196,11 +192,11 @@ def test_identity(): | |||
| dy_np = np.random.rand(*x.shape).astype("float32") | |||
| dy = mge.Tensor(dy_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| (y,) = apply(Identity(), x) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| (y,) = apply(Identity(), x) | |||
| grad(y, dy) | |||
| grad(y, dy) | |||
| np.testing.assert_array_equal(x.grad.numpy(), dy_np) | |||
| @@ -220,15 +216,14 @@ def test_elemwise_add(): | |||
| refs["y"] = TensorWeakRef(y) | |||
| return x + y | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| z = f(x, y) | |||
| del y | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| z = f(x, y) | |||
| del y | |||
| for k, r in refs.items(): | |||
| assert r() is None | |||
| grad(z, dz) | |||
| for k, r in refs.items(): | |||
| assert r() is None | |||
| grad(z, dz) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), dz_np.sum(0) * 2, decimal=5) | |||
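The weakref assertions deliberately run before `grad(z, dz)`: addition's backward needs no saved operands, so the intermediates captured in `refs` can already be collected even though the recorded graph is still alive and about to be replayed.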
| @@ -245,13 +240,12 @@ def test_elemwise_relu(): | |||
| refs["x"] = TensorWeakRef(x) | |||
| return relu(x) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| z = f(x) | |||
| assert refs["x"]() is None | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| z = f(x) | |||
| assert refs["x"]() is None | |||
| grad(z, dz) | |||
| grad(z, dz) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), [2.0, 0]) | |||
| @@ -269,21 +263,21 @@ def test_reshape(): | |||
| x_np = np.random.rand(2, 5).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| refs = {} | |||
| refs = {} | |||
| def f(x): | |||
| x = x * 1 | |||
| y = x.reshape(5, 2) | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| def f(x): | |||
| x = x * 1 | |||
| y = x.reshape(5, 2) | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| grad(y, F.ones_like(y)) | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy()) | |||
| @@ -291,21 +285,21 @@ def test_subtensor(): | |||
| x_np = np.random.rand(3, 3).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| refs = {} | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| refs = {} | |||
| def f(x): | |||
| x = x * 1 | |||
| y = x[1:-1, :2] | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| def f(x): | |||
| x = x * 1 | |||
| y = x[1:-1, :2] | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| grad(y, F.ones_like(y)) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal( | |||
| np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy() | |||
| ) | |||
| @@ -315,21 +309,21 @@ def test_IndexingMultiAxisVec(): | |||
| x_np = np.random.rand(3, 3).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| refs = {} | |||
| refs = {} | |||
| def f(x): | |||
| x = x * 1 | |||
| y = x[[0, 2], [0, 2]] | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| def f(x): | |||
| x = x * 1 | |||
| y = x[[0, 2], [0, 2]] | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| grad(y, F.ones_like(y)) | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal( | |||
| np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy() | |||
| ) | |||
| @@ -339,21 +333,21 @@ def test_AxisAddRemove(): | |||
| x_np = np.random.rand(1, 5).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| refs = {} | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| refs = {} | |||
| def f(x): | |||
| x = x * 1 | |||
| y = F.squeeze(F.expand_dims(x, 2), 0) | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| def f(x): | |||
| x = x * 1 | |||
| y = F.squeeze(F.expand_dims(x, 2), 0) | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| grad(y, F.ones_like(y)) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal( | |||
| np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy() | |||
| ) | |||
| @@ -363,10 +357,11 @@ def test_Broadcast(): | |||
| x_np = np.random.rand(3, 3, 1).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = F.broadcast_to(x, (3, 3, 10)) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| y = F.broadcast_to(x, (3, 3, 10)) | |||
| grad(y, F.ones_like(y)) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy()) | |||
| @@ -374,10 +369,11 @@ def test_interpolate_fastpath(): | |||
| x_np = np.random.rand(3, 3, 32, 32).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") | |||
| grad(y, F.ones_like(y)) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) | |||
| @@ -385,10 +381,11 @@ def test_Reduce_sum(): | |||
| x_np = np.random.rand(3, 3).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = x.sum(axis=0) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| y = x.sum(axis=0) | |||
| grad(y, F.ones_like(y)) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) | |||
| @@ -396,10 +393,11 @@ def test_Reduce_mean(): | |||
| x_np = np.random.rand(3, 3).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = x.mean(axis=0) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| y = x.mean(axis=0) | |||
| grad(y, F.ones_like(y)) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy()) | |||
| @@ -407,21 +405,21 @@ def test_addAxis(): | |||
| x_np = np.random.rand(3, 3).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| refs = {} | |||
| refs = {} | |||
| def f(x): | |||
| x = x * 1 | |||
| y = F.expand_dims(x, [2, 3]) | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| def f(x): | |||
| x = x * 1 | |||
| y = F.expand_dims(x, [2, 3]) | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| grad(y, F.ones_like(y)) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) | |||
| @@ -429,21 +427,21 @@ def test_removeAxis(): | |||
| x_np = np.random.rand(3, 3, 1, 1).astype("float32") | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| refs = {} | |||
| refs = {} | |||
| def f(x): | |||
| x = x * 1 | |||
| y = F.squeeze(x, [2, 3]) | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| def f(x): | |||
| x = x * 1 | |||
| y = F.squeeze(x, [2, 3]) | |||
| refs["x"] = TensorWeakRef(x) | |||
| return y | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| y = f(x) | |||
| for _, r in refs.items(): | |||
| assert r() is None | |||
| grad(y, F.ones_like(y)) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy()) | |||
| @@ -452,11 +450,14 @@ def test_dot(): | |||
| x = mge.Tensor(x) | |||
| u = F.ones((2,)) | |||
| v = F.ones((2,)) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| def f(x): | |||
| return F.dot(u, F.matmul(x, v)) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| def f(x): | |||
| return F.dot(u, F.matmul(x, v)) | |||
| y = f(x) | |||
| grad(y, F.ones_like(y)) | |||
| y = f(x) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy()) | |||
| @@ -267,25 +267,27 @@ def _gen_roi_inp(): | |||
| def test_roi_align(): | |||
| inp_feat, rois = _gen_roi_inp() | |||
| grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) | |||
| output_shape = (7, 7) | |||
| out_feat = F.vision.roi_align( | |||
| inp_feat, | |||
| rois, | |||
| output_shape=output_shape, | |||
| mode="average", | |||
| spatial_scale=1.0 / 4, | |||
| sample_points=2, | |||
| aligned=True, | |||
| ) | |||
| assert make_shape_tuple(out_feat.shape) == ( | |||
| rois.shape[0], | |||
| inp_feat.shape[1], | |||
| *output_shape, | |||
| ) | |||
| with Grad() as grad: | |||
| grad.wrt(inp_feat, callback=_save_to(inp_feat)) | |||
| output_shape = (7, 7) | |||
| out_feat = F.vision.roi_align( | |||
| inp_feat, | |||
| rois, | |||
| output_shape=output_shape, | |||
| mode="average", | |||
| spatial_scale=1.0 / 4, | |||
| sample_points=2, | |||
| aligned=True, | |||
| ) | |||
| assert make_shape_tuple(out_feat.shape) == ( | |||
| rois.shape[0], | |||
| inp_feat.shape[1], | |||
| *output_shape, | |||
| ) | |||
| grad(out_feat, tensor(F.ones_like(out_feat))) | |||
| grad(out_feat, tensor(F.ones_like(out_feat))) | |||
| assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | |||
| @@ -307,20 +309,23 @@ def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)): | |||
| def test_correlation(): | |||
| ## test case 0: check the grad shape | |||
| data1, data2 = _gen_correlation() | |||
| grad = Grad().wrt(data1, callback=_save_to(data1)) | |||
| out_feat = F.vision.correlation( | |||
| data1, | |||
| data2, | |||
| kernel_size=5, | |||
| max_displacement=4, | |||
| stride1=2, | |||
| stride2=2, | |||
| pad_size=2, | |||
| is_multiply=True, | |||
| ) | |||
| with Grad() as grad: | |||
| grad.wrt(data1, callback=_save_to(data1)) | |||
| out_feat = F.vision.correlation( | |||
| data1, | |||
| data2, | |||
| kernel_size=5, | |||
| max_displacement=4, | |||
| stride1=2, | |||
| stride2=2, | |||
| pad_size=2, | |||
| is_multiply=True, | |||
| ) | |||
| grad(out_feat, tensor(F.ones_like(out_feat))) | |||
| grad(out_feat, tensor(F.ones_like(out_feat))) | |||
| assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape) | |||
| ## test case 1: from https://github.com/NVIDIA/flownet2-pytorch/issues/194 | |||
| @@ -391,32 +396,36 @@ def test_correlation(): | |||
| def test_roi_pooling(): | |||
| inp_feat, rois = _gen_roi_inp() | |||
| grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) | |||
| output_shape = (7, 7) | |||
| out_feat = F.vision.roi_pooling( | |||
| inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, | |||
| ) | |||
| assert make_shape_tuple(out_feat.shape) == ( | |||
| rois.shape[0], | |||
| inp_feat.shape[1], | |||
| *output_shape, | |||
| ) | |||
| with Grad() as grad: | |||
| grad.wrt(inp_feat, callback=_save_to(inp_feat)) | |||
| output_shape = (7, 7) | |||
| out_feat = F.vision.roi_pooling( | |||
| inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, | |||
| ) | |||
| assert make_shape_tuple(out_feat.shape) == ( | |||
| rois.shape[0], | |||
| inp_feat.shape[1], | |||
| *output_shape, | |||
| ) | |||
| grad(out_feat, tensor(F.ones_like(out_feat))) | |||
| grad(out_feat, tensor(F.ones_like(out_feat))) | |||
| assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | |||
| def test_adaptive_avg_pool2d(): | |||
| inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||
| oshp = (2, 2) | |||
| grad = Grad().wrt(inp, callback=_save_to(inp)) | |||
| outp = F.adaptive_avg_pool2d(inp, oshp,) | |||
| assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
| np.testing.assert_equal( | |||
| outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) | |||
| ) | |||
| with Grad() as grad: | |||
| grad.wrt(inp, callback=_save_to(inp)) | |||
| outp = F.adaptive_avg_pool2d(inp, oshp,) | |||
| assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
| np.testing.assert_equal( | |||
| outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) | |||
| ) | |||
| grad(outp, tensor(F.ones_like(outp))) | |||
| grad(outp, tensor(F.ones_like(outp))) | |||
| assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||
| np.testing.assert_equal( | |||
| inp.grad.numpy(), | |||
| @@ -439,14 +448,16 @@ def test_adaptive_avg_pool2d(): | |||
| def test_adaptive_max_pool2d(): | |||
| inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||
| oshp = (2, 2) | |||
| grad = Grad().wrt(inp, callback=_save_to(inp)) | |||
| outp = F.adaptive_max_pool2d(inp, oshp,) | |||
| assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
| np.testing.assert_equal( | |||
| outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) | |||
| ) | |||
| with Grad() as grad: | |||
| grad.wrt(inp, callback=_save_to(inp)) | |||
| outp = F.adaptive_max_pool2d(inp, oshp,) | |||
| assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
| np.testing.assert_equal( | |||
| outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) | |||
| ) | |||
| grad(outp, tensor(F.ones_like(outp))) | |||
| grad(outp, tensor(F.ones_like(outp))) | |||
| assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||
| np.testing.assert_equal( | |||
| inp.grad.numpy(), | |||
| @@ -351,7 +351,7 @@ def test_expand_dims_for_scalar(): | |||
| for axis in [1, -2, (1, 2), (-2, -3)]: | |||
| np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis) | |||
| np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis) | |||
| np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| @@ -9,6 +9,7 @@ | |||
| import inspect | |||
| import io | |||
| import itertools | |||
| import random | |||
| from tempfile import mkstemp | |||
| import numpy as np | |||
| @@ -25,7 +26,7 @@ from megengine.core.ops import builtin as ops | |||
| from megengine.core.ops.builtin import Elemwise | |||
| from megengine.core.tensor.utils import isscalar | |||
| from megengine.functional import exp, log | |||
| from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace | |||
| from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace | |||
| from megengine.module import Module | |||
| from megengine.random import normal, uniform | |||
| from megengine.utils.naming import AutoNaming | |||
| @@ -464,36 +465,92 @@ def test_trace_warp_perspective(): | |||
| f(x, M) | |||
| def test_raise_on_trace(): | |||
| step_count = 0 | |||
| catch_count = 0 | |||
| bad_step = 10 | |||
| @pytest.mark.parametrize( | |||
| "normal_expr, mismatch_expr, reason", | |||
| [ | |||
| ("a + b + c", "a + b - c", "operator mismatch"), | |||
| ("a + b + 1", "a + b + 2", "tensors not equals"), | |||
| ("((a + b), (b + c))[0]", "a + b", "mismature end"), | |||
| ("a + b + c", "c + (a + b)", "expect internal node, got external"), | |||
| ("c + (a + b)", "a + b + c", "expect external node, got internal"), | |||
| ("a + b + c", "a + b + c + c", "too many instructions"), | |||
| ("((a + b), (b + c))[1]", "((a + b), (b + c))[0]", "data unreadable"), | |||
| ("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"), | |||
| ], | |||
| ) | |||
| def test_trace_mismatch(normal_expr, mismatch_expr, reason): | |||
| a = tensor([1, 2, 3, 4]) | |||
| b = tensor([5, 6, 7, 8]) | |||
| c = tensor([9, 0, 1, 2]) | |||
| mismatch = False | |||
| @trace(symbolic=True) | |||
| def fn(a, b, c): | |||
| if not mismatch: | |||
| result = eval(normal_expr) | |||
| else: | |||
| result = eval(mismatch_expr) | |||
| return result | |||
| for i in range(20): | |||
| try: | |||
| d = fn(a, b, c) | |||
| except TraceError as e: | |||
| assert mismatch | |||
| assert str(e) == "trace error because {}".format(reason) | |||
| except Exception: | |||
| pytest.fail("unexpected trace error") | |||
| else: | |||
| assert not mismatch | |||
| np.testing.assert_equal(d.numpy(), eval(normal_expr).numpy()) | |||
| mismatch = random.random() > 0.8 | |||
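Each parametrized case compiles the trace from `normal_expr` on the first iteration, then randomly swaps in `mismatch_expr`; a mismatching replay must raise `TraceError` carrying exactly the parametrized reason, while matching replays must keep producing the reference value.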
| class CatchMe(Exception): | |||
| pass | |||
| def test_exception_in_trace(): | |||
| a = tensor([1, 2, 3, 4]) | |||
| b = tensor([5, 6, 7, 8]) | |||
| c = tensor([9, 0, 1, 2]) | |||
| @trace | |||
| def add_abc(a, b, c): | |||
| ps = a + b | |||
| result = ps + c | |||
| if step_count == bad_step: | |||
| raise CatchMe("catch me") | |||
| mismatch = False | |||
| exc = Exception() | |||
| @trace(symbolic=True) | |||
| def fn(a, b, c): | |||
| result = a + b | |||
| if not mismatch: | |||
| result += c | |||
| else: | |||
| raise exc | |||
| return result | |||
| for i in range(100): | |||
| for i in range(20): | |||
| try: | |||
| d = add_abc(a, b, c) | |||
| except CatchMe as e: | |||
| catch_count += 1 | |||
| d = fn(a, b, c) | |||
| except TraceError as e: | |||
| pytest.fail("unexpected trace error") | |||
| except Exception as e: | |||
| assert mismatch | |||
| assert e is exc | |||
| else: | |||
| assert not mismatch | |||
| np.testing.assert_equal(d.numpy(), (a + b + c).numpy()) | |||
| step_count += 1 | |||
| mismatch = random.random() > 0.8 | |||
| assert catch_count == 1 | |||
| def test_graph_error(): | |||
| a = tensor(np.arange(8).reshape((2, 4))) | |||
| b = tensor(np.arange(8).reshape((2, 4))) | |||
| @trace(symbolic=True) | |||
| def fn(a, b): | |||
| return a + b | |||
| fn(a, b) | |||
| with pytest.raises(RuntimeError): | |||
| fn(a, b.transpose()) | |||
| fn(a, b) | |||
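The final `fn(a, b)` is the point of the test: after the transposed input makes the replay raise `RuntimeError`, the trace must remain usable, so a call with the original shapes is expected to succeed again.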
| @pytest.mark.parametrize("trace_mode", [False, True]) | |||
| @@ -653,9 +710,10 @@ def test_trace_jit_config(): | |||
| x = tensor(2) | |||
| y = func(x) | |||
| func._compile() | |||
| y = func(x) | |||
| # func._compile() | |||
| options = func._graph.options | |||
| options = func._trace.options | |||
| mapping = {None: 0, False: 1, True: 2} | |||
| assert options.graph_opt.jit == 0 | |||
| assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle] | |||
| @@ -82,9 +82,10 @@ def test_tqt(): | |||
| x = mge.tensor(x, dtype="float32") | |||
| s = mge.tensor(s, dtype="float32") | |||
| g_y = mge.tensor(g_y, dtype="float32") | |||
| grad = Grad().wrt(x, s, callback=cb) | |||
| y = tqt_forward(-127, 127, x, s) | |||
| grad(y, g_y) | |||
| with Grad() as grad: | |||
| grad.wrt(x, s, callback=cb) | |||
| y = tqt_forward(-127, 127, x, s) | |||
| grad(y, g_y) | |||
| g_x, g_s = g | |||
| np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) | |||
| @@ -131,14 +132,16 @@ def test_fakequant(): | |||
| # test backward | |||
| x = tensor(inp_data, dtype=np.float32) | |||
| grad = Grad().wrt(x, callback=_save_to(x)) | |||
| y = fake_quant_tensor(x, qparams) | |||
| grad(y, tensor(F.ones_like(x))) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=_save_to(x)) | |||
| y = fake_quant_tensor(x, qparams) | |||
| grad(y, tensor(F.ones_like(x))) | |||
| x1 = tensor(inp_data, dtype=np.float32) | |||
| grad = Grad().wrt(x1, callback=_save_to(x1)) | |||
| y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax) | |||
| grad(y1, tensor(F.ones_like(x1))) | |||
| with Grad() as grad: | |||
| grad.wrt(x1, callback=_save_to(x1)) | |||
| y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax) | |||
| grad(y1, tensor(F.ones_like(x1))) | |||
| assert np.allclose(x.grad.numpy(), x1.grad.numpy()) | |||
| assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape) | |||
| @@ -237,9 +240,10 @@ def test_lsq(): | |||
| grad_s = mge.tensor(grad_s, dtype="float32") | |||
| g_y = mge.tensor(g_y, dtype="float32") | |||
| grad = Grad().wrt(x, s, callback=cb) | |||
| y = lsq_forward(-127, 127, x, s, zero_point, grad_s) | |||
| grad(y, g_y) | |||
| with Grad() as grad: | |||
| grad.wrt(x, s, callback=cb) | |||
| y = lsq_forward(-127, 127, x, s, zero_point, grad_s) | |||
| grad(y, g_y) | |||
| g_x, g_s = g | |||
| np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7) | |||
| @@ -430,9 +430,10 @@ def test_ShuffleRNG(): | |||
| n, m = 6, 3 | |||
| arr = np.arange(n * m) | |||
| out0 = Tensor(arr, dtype="float32") | |||
| grad = Grad().wrt(out0, callback=cb) | |||
| random.shuffle(out0) | |||
| grad(out0, F.ones_like(out0)) | |||
| with Grad() as grad: | |||
| grad.wrt(out0, callback=cb) | |||
| random.shuffle(out0) | |||
| grad(out0, F.ones_like(out0)) | |||
| m1 = RNG(seed=111, device="xpu0") | |||
| m2 = RNG(seed=111, device="xpu1") | |||
| m3 = RNG(seed=222, device="xpu0") | |||