GitOrigin-RevId: 2dd4e460ac
tags/v1.2.0
@@ -301,7 +301,7 @@ class GradManager:
        if tensor is None:
            return
        def callback(_, grad, callbacks=spec.callbacks):
        def callback(grad, callbacks=spec.callbacks):
            for cb in callbacks:
                grad = cb(tensor, grad)
            self._gradients[id(tensor)] = grad
@@ -16,6 +16,7 @@ import numpy as np
import megengine as mge
from .._imperative_rt import core2
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -418,3 +419,28 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
    return None
class Grad:
    def __init__(self):
        self._impl = core2.GradKey()
    def wrt(self, *tensors, callback=None):
        for x in tensors:
            self._impl.attach(x, callback)
        return self
    def __call__(self, ys, dys):
        from collections.abc import Sequence
        if not isinstance(ys, Sequence):
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]
        core2.backward(self._impl, ys, dys)
    def __enter__(self):
        return self
    def __exit__(self, _1, _2, _3):
        del self._impl
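For reference, a minimal usage sketch of the new core2-backed Grad front end above (the forward function, the callback body, and ones_like are illustrative, not part of this patch):

    # hypothetical usage; the callback receives the accumulated gradient
    with Grad() as grad:
        grad.wrt(x, callback=lambda dx: print(dx))
        y = some_forward(x)        # any differentiable computation on x
        grad(y, ones_like(y))      # ys/dys may be single tensors or sequences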
@@ -6,11 +6,18 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .._imperative_rt.core2 import Tensor
from ..tensor.core import OpBase, TensorBase, apply
class Const(OpBase):
class Const:
    def __init__(self, value=None, *, dtype=None, device=None):
        self.value = value
        self.value = np.asarray(value, dtype=dtype)
        self.dtype = dtype
        self.device = device
    def __call__(self, *reference):
        Wrapper = type(reference[0])
        return (Wrapper(self.value, self.dtype, self.device),)
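A short sketch of how Const is used throughout this diff (the reference tensor only supplies the wrapper type and device; x is illustrative):

    # hypothetical call; returns a 1-tuple wrapped in type(x)
    (c,) = Const(1.0, dtype=np.float32, device=x.device)(x)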
@@ -13,9 +13,17 @@ import sys
import typing
from abc import ABC
from .._imperative_rt.core2 import apply as apply2
from .multipledispatch import Dispatcher
def apply_op(op, *args):
    Wrapper = type(args[0])
    args = [arg._tensor for arg in args]
    results = apply2(op, *args)
    return tuple(map(Wrapper, results))
class OpBase(ABC):
    def __call__(self, *args):
        return apply(self, *args)
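Sketch of the round trip apply_op performs (names are illustrative; a and b stand for TensorWrapper instances holding raw core2 Tensors in ._tensor):

    # unwrap to raw core2 Tensors, run the op, re-wrap in the caller's type
    (out,) = apply_op(Elemwise(mode="add"), a, b)
    assert type(out) is type(a)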
@@ -10,10 +10,10 @@ from typing import Iterable
import numpy as np
from .._imperative_rt.core2 import Tensor, apply
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.special import Const
from .core import TensorBase, TensorWrapperBase, apply
from .utils import astensor1d, isscalar, make_shape_tuple
@@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
        return True
    def get_index(i):
        if not isinstance(i, (TensorBase, TensorWrapperBase)):
        if not isinstance(i, (Tensor)):
            if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
                (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
            else:
                (i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
            return i
        assert isinstance(i, (TensorBase, TensorWrapperBase))
        assert isinstance(i, Tensor)
        if i.dtype != np.bool_:
            return i
        _, ind = apply(builtin.CondTake(), i, i)
@@ -198,8 +198,8 @@ def try_condtake(tensor, index):
        return []
    if isinstance(index, np.ndarray):
        (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
    assert isinstance(index, (TensorBase, TensorWrapperBase))
    if not isinstance(tensor, (TensorWrapperBase, TensorBase)):
    assert isinstance(index, Tensor)
    if not isinstance(tensor, Tensor):
        raise TypeError("input must be a tensor")
    if tensor.device != index.device:
        raise ValueError(
@@ -227,7 +227,7 @@ def getitem(tensor, index):
    op = builtin.IndexingMultiAxisVec(items=items)
    (result,) = apply(op, tensor, *tensors)
    if ret_scalar:
        result.__wrapped__._data._isscalar = True
        result.setscalar()
    return result
@@ -239,7 +239,7 @@ def setitem(tensor, index, value):
        if index.shape[0] == 0:
            return tensor
        tensor = tensor.reshape(-1)
    if not isinstance(value, (TensorBase, TensorWrapperBase)):
    if not isinstance(value, Tensor):
        op = Const(value, dtype=tensor.dtype, device=tensor.device)
        (value,) = op(tensor)
    tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
@@ -250,6 +250,7 @@ def setitem(tensor, index, value):
        op = builtin.Subtensor(items=items)
    else:
        op = builtin.IndexingMultiAxisVec(items=items)
    (tmp_result,) = apply(op, tensor, *tensors)
    # XXX: broadcast can always be applied even if shapes are equal
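The net effect of the indexing changes above, sketched with hypothetical values (a bool mask is first lifted to a tensor via Const, then lowered through CondTake):

    mask = np.array([True, False, True])
    y = x[mask]        # getitem: mask -> Const(...) -> CondTake indices
    x[mask] = 0.0      # setitem: a non-tensor value is lifted via Const too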
@@ -8,19 +8,20 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import collections
from typing import Union
import numpy as np
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const
from . import utils
from .core import OpBase, TensorBase, TensorWrapperBase, apply
from .core import OpBase, TensorBase, TensorWrapperBase
from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
from .raw_tensor import RawTensor, as_raw_tensor
from .tensor import Tensor
from .utils import isscalar
from .utils import make_shape_tuple as _make_shape_tuple
from .utils import setscalar
@@ -41,6 +42,7 @@ def _elwise(*args, mode):
    )
    args = utils.convert_inputs(*args)
    (result,) = apply(op, *args)
    _isscalar = True
    for i in args:
        if isscalar(i) == False:
@@ -84,9 +86,7 @@ def _reshape(x, shape):
        if unspec_axis is not None:
            raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
        unspec_axis = i
    shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
    if unspec_axis is None:
        op = builtin.Reshape()
    else:
@@ -181,7 +181,6 @@ def _reduce(mode):
        elif isinstance(axis, collections.abc.Iterable):
            axis = list(axis)
            axis.sort(reverse=True)
            for ai in axis:
                op = builtin.Reduce(mode=mode, axis=ai)
                (data,) = apply(op, data)
@@ -221,10 +220,7 @@ def _todo(*_):
def _expand_args(args):
    if len(args) == 1:
        if isinstance(
            args[0],
            (collections.abc.Sequence, TensorBase, TensorWrapperBase, np.ndarray),
        ):
        if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray)):
            args = args[0]
    return args
@@ -240,9 +236,8 @@ class ArrayMethodMixin(abc.ABC):
        return self.numpy().astype(dtype)
    def __array_wrap__(self, array):
        return TensorWrapper(
            as_raw_tensor(array, dtype=array.dtype, device=self.device)
        )
        Wrapper = type(self)
        return Wrapper(array, dtype=array.dtype, device=self.device)
    @abc.abstractmethod
    def _reset(self, other):
@@ -253,7 +248,11 @@ class ArrayMethodMixin(abc.ABC):
        pass
    @abc.abstractproperty
    def shape(self) -> tuple:
    def shape(self) -> Union[tuple, Tensor]:
        pass
    @abc.abstractproperty
    def _tuple_shape(self) -> tuple:
        pass
    @abc.abstractmethod
@@ -331,7 +330,7 @@ class ArrayMethodMixin(abc.ABC):
    __complex__ = lambda self: complex(self.item())
    def __len__(self):
        shape = self.__wrapped__.shape
        shape = self._tuple_shape
        if shape:
            return int(shape[0])
        raise TypeError("ndim is 0")
@@ -352,7 +351,7 @@ class ArrayMethodMixin(abc.ABC):
    @property
    def ndim(self):
        shape = self.__wrapped__.shape
        shape = self._tuple_shape
        if shape is None:
            raise ValueError("unknown ndim")
        return len(shape)
@@ -480,22 +479,52 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
        self.__wrapped__._swap_out()
class TensorWrapper(GenericTensorWrapper):
    def __init__(self, data, dtype=None, device=None):
        if isinstance(data, TensorWrapperBase):
            data = data.__wrapped__
        elif not isinstance(data, TensorBase):
            assert data is not None, "Cannot init a tensor with data as None"
            data = Tensor(as_raw_tensor(data, dtype=dtype, device=device))
        super().__init__(data)
class TensorWrapper(ArrayMethodMixin, TensorBase):
    def __init__(self, data, dtype=None, device=None, isscalar=False):
        self._isscalar = isscalar
        if isinstance(data, Tensor):
            self._tensor = data
        else:
            if device is None:
                device = CompNode._get_default_device()
            self._tensor = Tensor(data, dtype, device)
    def _reset(self, other):
        if isinstance(other, TensorWrapperBase):
            self.__wrapped__ = other.__wrapped__
        elif isinstance(other, TensorBase):
            self.__wrapped__ = other
        else:
            self._reset(type(self)(other, dtype=self.dtype, device=self.device))
        if not isinstance(other, __class__):
            raise TypeError(type(other))
        self._tensor = other._tensor
        return self
    @property
    def dtype(self):
        return self._tensor.dtype
    @property
    def shape(self):
        if self._isscalar:
            return ()
        shape = self._tensor.shape
        if shape == () or not use_symbolic_shape():
            return shape
        return apply(GetVarShape(), self)[0]
    @property
    def device(self):
        return self._tensor.device
    def numpy(self):
        if self._isscalar:
            return self._tensor.numpy().squeeze()
        return self._tensor.numpy()
    def _drop(self):
        self._tensor._drop()
    def _swap_in(self):
        self._tensor._swap_in()
    def _swap_out(self):
        self._tensor._swap_out()
    def __repr__(self):
        piece = "Tensor("
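A construction sketch for the rewritten TensorWrapper (values illustrative; the device defaults to the CompNode default as in __init__ above):

    w = TensorWrapper(np.arange(4, dtype="float32"))   # raw core2 Tensor lives in ._tensor
    s = TensorWrapper(np.float32(1.0), isscalar=True)
    assert s.shape == () and s.numpy().ndim == 0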
@@ -11,9 +11,10 @@ from typing import Iterable, Union
import numpy as np
from .._imperative_rt.core2 import Tensor, apply
from ..ops import builtin
from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase
from .dtype import is_equal, is_quantize
_enable_convert_inputs = True
@@ -109,7 +110,7 @@ def dtype_promotion(inputs):
def get_device(inputs):
    device = None
    for i in inputs:
        if isinstance(i, (TensorWrapperBase, TensorBase)):
        if isinstance(i, Tensor):
            if device is None:
                device = i.device
            elif device != i.device:
@@ -126,30 +127,31 @@ def concatenate(inputs, axis=0, *, device=None):
        return convert_single_value(x, inputs, dtype=dtype)
    inputs = tuple(map(convert, inputs))
    (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inputs)
    (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs)
    return result
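Sketch of the updated call path (the device is now passed as a CompNode directly, without the old .to_c() conversion; a and b are illustrative tensors on one device):

    out = concatenate([a, b], axis=0, device=a.device)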
def astype(x, dtype):
    dtype = np.dtype(dtype)
    if not is_equal(x.dtype, dtype):
        isscalar = x.__wrapped__._data._isscalar
        isscalar = x.isscalar()
        (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
        x.__wrapped__._data._isscalar = isscalar
        if isscalar:
            x.setscalar()
    return x
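Usage sketch (the scalar flag now survives the dtype conversion, assuming x is a core2 Tensor):

    y = astype(x, "float16")
    assert y.isscalar() == x.isscalar()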
def convert_single_value(v, inputs, *, dtype=None, device=None):
    tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))]
    tensors = [i for i in inputs if isinstance(i, Tensor)]
    assert len(tensors) > 0
    if isinstance(v, (TensorWrapperBase, TensorBase)):
    if isinstance(v, (TensorWrapperBase, Tensor)):
        v = astype(v, v.dtype if is_quantize(v.dtype) else dtype)
    else:
        (v,) = Const(v, dtype=dtype, device=device)(*tensors)
    return v
def convert_inputs(*args: TensorBase):
def convert_inputs(*args: Tensor):
    if not _enable_convert_inputs:
        return args
@@ -167,7 +169,7 @@ def convert_inputs(*args: TensorBase):
def result_type(*args):
    dtypes = []
    for i in args:
        if isinstance(i, (TensorWrapperBase, TensorBase)):
        if isinstance(i, Tensor):
            dtypes.append(i.dtype)
            continue
        try:
@@ -178,25 +180,16 @@ def result_type(*args):
def isscalar(x):
    if isinstance(x, TensorWrapperBase):
        x = x.__wrapped__
    if hasattr(x, "_isscalar"):
        return x._isscalar
    if isinstance(x, TensorBase):
        return x._data._isscalar
    if isinstance(x, Tensor):
        return x.isscalar()
    return np.isscalar(x)
def setscalar(x):
    if isinstance(x, TensorWrapperBase):
        x = x.__wrapped__
    if hasattr(x, "_isscalar"):
        x._isscalar = True
    elif isinstance(x, TensorBase):
        x._data._isscalar = True
    if isinstance(x, Tensor):
        x.setscalar()
    else:
        raise NotImplementedError("Unsupported type {}".format(type(x)))
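Behavior sketch for the two helpers above (t stands for any core2 Tensor):

    isscalar(1.0)    # True, via the np.isscalar fallback
    isscalar(t)      # delegates to t.isscalar()
    setscalar(t)     # marks t as scalar; raises NotImplementedError for other types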
@@ -215,25 +208,24 @@ def astensor1d(x, *reference, dtype=None, device=None):
    else:
        if ndim != 0 and ndim != 1:
            raise ValueError("ndim != 1 or 0, got: %d" % ndim)
        if not isinstance(x, (TensorBase, TensorWrapperBase)):
        if not isinstance(x, Tensor):
            (x,) = Const(x, dtype=dtype, device=device)(*reference)
        return x
    if not isinstance(x, collections.abc.Sequence):
        raise TypeError
    if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x):
    if any(isinstance(i, Tensor) for i in x):
        x = concatenate(x, device=device)
        if dtype is not None:
            x = astype(x, dtype)
        return x
    (x,) = Const(x, dtype=dtype, device=device)(*reference)
    return x
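A sketch of the inputs astensor1d accepts (the reference tensor x supplies the wrapper type and device; t0 is an illustrative tensor):

    astensor1d(2, x)                       # scalar -> constant via Const
    astensor1d((2, 3), x, dtype="int32")   # plain sequence -> constant tensor
    astensor1d([t0, 1], x)                 # sequence containing tensors -> concatenate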
def _expand_int(s, i):
    if isinstance(i, (TensorBase, TensorWrapperBase)):
    if isinstance(i, Tensor):
        i_np = i.numpy()
        if i_np.ndim == 0:
            s.append(int(i_np))
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
    Tracer,
@@ -17,7 +18,6 @@ from ..core.autodiff.grad import (
    tracer_apply,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply
from ..device import get_default_device
from ..tensor import tensor
@@ -39,71 +39,6 @@ __all__ = [
]
@apply.register()
def _(op: RemoteSend, *args: Tensor):
    ret = tensor_apply(op, *args)
    # set extra information
    tracer_set = dict()
    for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
        tracer_set[k.name] = True
    # check tracer_set in remote_recv
    get_client().set_remote_tracer(op.key, tracer_set)
    return ret
@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    def backward(*args):
        return [
            remote_recv(
                op.rank_to,
                inputs[0].shape,
                inputs[0].dtype,
                device=str(inputs[0].device),
                inp=inputs[0],
            )
        ]
    return backward, [True]
@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    def has_grad(opnode, reached):
        return get_client().check_is_grad(op.key)
    return has_grad
@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    return True
@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    def backward(*output_grads):
        return [remote_send(output_grads[0], op.rank_from)]
    return backward, [True]
@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    def has_grad(opnode, reached):
        ret = False
        for v in opnode.outputs:
            if v() in reached:
                ret = True
                break
        get_client().set_is_grad(op.key, ret)
        return ret
    return has_grad
def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    assert isinstance(group, Group)
@@ -17,8 +17,8 @@ import numpy as np
from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..core.tensor.core import apply
from ..functional.utils import copy
from ..tensor import Tensor
from ..utils.future import Future
@@ -228,7 +228,6 @@ class AllreduceCallback:
        self._packing_size[dtype] = 0
    def __call__(self, param, grad):
        param = param.__wrapped__
        gm = get_backwarding_grad_manager()
        assert isinstance(gm, GradManager)
        if gm not in self._marked_gm:
@@ -9,10 +9,10 @@
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import functools
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..jit.tracing import is_tracing
@@ -12,10 +12,11 @@ import math
import numbers
from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.core import TensorBase, TensorWrapperBase
from ..tensor import Tensor
from .elemwise import clip, exp, log, log1p
from .tensor import reshape, squeeze
@@ -10,12 +10,12 @@
from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor, apply
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed
from ..jit.tracing import is_tracing
@@ -1565,9 +1565,7 @@ def indexing_one_hot(
    [1.]
    """
    assert isinstance(
        src, (TensorWrapperBase, TensorBase)
    ), "src must be of Tensor type"
    assert isinstance(src, Tensor), "src must be of Tensor type"
    op = builtin.IndexingOneHot(axis=axis)
    index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
    (result,) = apply(op, src, index)
@@ -8,8 +8,8 @@
# pylint: disable=too-many-lines
from typing import Tuple, Union
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.core import apply
from ..tensor import Tensor
from .debug_param import get_conv_execution_strategy
from .types import _pair, _pair_nonzero
@@ -14,10 +14,10 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor, apply
from ..core._wrap import device as as_device
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
from ..core.tensor.utils import (
    astensor1d,
@@ -611,11 +611,11 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
    """
    x, y = convert_inputs(x, y)
    if not isinstance(x, (TensorWrapperBase, TensorBase)):
    if not isinstance(x, Tensor):
        raise TypeError("input x must be a tensor")
    if not isinstance(y, (TensorWrapperBase, TensorBase)):
    if not isinstance(y, Tensor):
        raise TypeError("input y must be a tensor")
    if not isinstance(mask, (TensorWrapperBase, TensorBase)):
    if not isinstance(mask, Tensor):
        raise TypeError("mask must be a tensor")
    if mask.dtype != np.bool_:
        raise ValueError("mask must be bool")
@@ -668,9 +668,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
    [1. 4.] [0 3]
    """
    if not isinstance(x, (TensorWrapperBase, TensorBase)):
    if not isinstance(x, Tensor):
        raise TypeError("input must be a tensor")
    if not isinstance(mask, (TensorWrapperBase, TensorBase)):
    if not isinstance(mask, Tensor):
        raise TypeError("mask must be a tensor")
    if mask.dtype != np.bool_:
        raise ValueError("mask must be bool")
@@ -11,10 +11,10 @@ from typing import Iterable, Union
import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core._wrap import device as as_device
from ..core.ops.builtin import Copy, Identity
from ..core.tensor import Tensor
from ..core.tensor.core import apply
from ..tensor import Tensor
from .math import topk as _topk
from .tensor import broadcast_to, transpose
@@ -10,9 +10,9 @@ from typing import Iterable, Optional
from .. import Tensor
from ..core._imperative_rt import invoke_op
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from ..core.tensor.core import apply
from .rng import _random_seed_generator
__all__ = ["normal", "uniform"]
@@ -10,26 +10,66 @@
import collections
from .core import Tensor as _Tensor
from .core.ops.builtin import Copy
from .core.tensor.core import apply
import numpy as np
from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.raw_tensor import as_device
from .core.tensor.tensor_wrapper import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .utils.deprecation import deprecated
class Tensor(_Tensor):
class Tensor(_Tensor, ArrayMethodMixin):
    grad = None
    dmap_callback = None
    q_dict = {"mode": None, "scale": None, "zero_point": None}
    def __init__(self, data, dtype=None, device=None):
    def __new__(cls, data, dtype=None, device=None):
        if device is None:
            device = get_default_device()
        self.q_dict = {"mode": None, "scale": None, "zero_point": None}
        super().__init__(data, dtype=dtype, device=device)
            cn = get_default_device()
        elif isinstance(device, str):
            if cls.dmap_callback is not None:
                cn = CompNode(cls.dmap_callback(device))
            else:
                cn = CompNode(device)
        else:
            assert isinstance(device, CompNode)
            cn = device
        if isinstance(data, _Tensor):
            obj = _Tensor.__new__(cls, data)
        else:
            obj = _Tensor.__new__(cls, data, dtype, cn)
        return obj
    @property
    def shape(self):
        shape = super().shape
        if shape == () or not use_symbolic_shape():
            return shape
        return apply(GetVarShape(), self)[0]
    @property
    def _tuple_shape(self):
        return super().shape
    def __repr__(self):
        piece = "Tensor("
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        piece += ", device={}".format(self.device) + ")"
        return piece
    @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
    def set_value(self, value):
        if not isinstance(value, _Tensor):
            value = Tensor(value, dtype=self.dtype, device=self.device)
        self._reset(value)
    @deprecated(version="1.0", reason="use *= 0 instead")
@@ -61,27 +101,22 @@ class Tensor(_Tensor):
    def __hash__(self):
        return id(self)
    def __getnewargs__(self):
        r"""__getnewargs__ will be called for pickle serialization or deep copy
        """
        return (self.numpy(), self.dtype, self.device.logical_name)
    def __getstate__(self):
        r"""__getstate__ will be called for pickle serialization or deep copy
        """
        state = {
            "data": self.numpy(),
            "device": self.device.logical_name,
            "dtype": self.dtype,
            "qdict": self.q_dict,
        }
        return state
    def __setstate__(self, state):
        data = state.pop("data")
        logical_device = state.pop("device")
        if self.dmap_callback is not None:
            assert isinstance(logical_device, str)
            logical_device = self.dmap_callback(logical_device)
        dtype = state.pop("dtype")
        self.q_dict = state.pop("qdict")
        super().__init__(data, dtype=dtype, device=logical_device)
    def detach(self):
        r"""
@@ -89,8 +124,7 @@ class Tensor(_Tensor):
        during backward gradient calculation, i.e. its gradient is zero.
        """
        Wrapper = type(self)
        Tensor = type(self.__wrapped__)
        return Wrapper(Tensor(self.__wrapped__._data))
        return Wrapper(self)
tensor = Tensor
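A round-trip sketch grounded in the __getnewargs__/__getstate__/__setstate__ methods above (values illustrative):

    import pickle
    t = Tensor([1, 2, 3], dtype="int32")
    t2 = pickle.loads(pickle.dumps(t))   # data, dtype, logical device, q_dict survive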
@@ -0,0 +1,404 @@
/**
 * \file imperative/python/src/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/utils/mempool.h"
namespace py = pybind11;
namespace mgb::imperative::python {
namespace {
struct GradSlotWeakPtr {
    std::weak_ptr<GradFn> grad_fn;
    size_t idx;
};
} // namespace
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
    using Base = intrusive_list::Node<GradProducerRecord>;
    GradProducerRecord() = default;
    GradProducerRecord(GradProducerRecord::head_t& head) : Base(intrusive_list::after_t{}, head) {}
    // GradProducerRecord(GradProducerRecord&&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&&) = default;
};
struct GradSlot {
    std::shared_ptr<Tensor> grad;
    py::object callback;
    GradProducerRecord::head_t producer_head;
};
struct GradSlotProducerPtr : GradSlotPtr {
    GradProducerRecord producer_record;
    GradSlotProducerPtr() = default;
    GradSlotProducerPtr(GradInfo& info) : GradSlotPtr(info), producer_record(info->producer_head) {}
};
struct GradFn : std::enable_shared_from_this<GradFn> {
    static MemPool<GradFn> pool;
    std::weak_ptr<GradKey> key;
    SmallVector<GradSlot> slots;
    SmallVector<GradSlotProducerPtr> dsts;
    SmallVector<std::shared_ptr<Tensor>> closure;
    std::shared_ptr<BackwardGraphResult> backward_graph;
    bool in_ref_keeper = false;
    static void deleter(GradFn* ptr) {
        pool.free(ptr);
    }
    static std::shared_ptr<GradFn> make() {
        return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
    }
    void clear() {
        key.reset();
        slots.clear();
        dsts.clear();
        closure.clear();
        backward_graph.reset();
    }
};
GradSlot* GradSlotPtr::operator->() {
    return &grad_fn->slots[idx];
}
namespace {
struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
    }
} backward_graph_cache;
std::shared_ptr<BackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash
    static_assert(alignof(size_t) % alignof(bool) == 0);
    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
    alignas(alignof(size_t)) std::byte buf[buf_size];
    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
    bool* bool_ptr0 = bool_ptr;
    *(size_t_ptr++) = ctx.op->hash();
    for (size_t i = 0; i < ctx.nargs; ++i) {
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
    }
    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
    size_t key = XXHash{}.update(buf, buf_size).digest();
    auto&& iter = backward_graph_cache.find(key);
    if (iter != backward_graph_cache.end()) {
        return iter->second;
    }
    // slow path
    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
    SmallVector<bool> input_requires_grad(ctx.nargs, false);
    SmallVector<bool> output_has_grad(outputs.size(), true);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        inputs[i].comp_node = ctx.args[i]->comp_node();
        inputs[i].layout.dtype = ctx.args[i]->dtype();
        input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
    }
    auto result = std::make_shared<BackwardGraphResult>(
            proxy_graph_detail::make_backward_graph(
                    *ctx.op, inputs, input_requires_grad, output_has_grad));
    if (!result->backward) {
        result.reset();
    }
    backward_graph_cache.emplace(key, result);
    return result;
}
} // namespace
apply_result_t apply_grad(ApplyContext& ctx) {
    std::shared_ptr<GradKey> grad_key;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        auto* tensor = ctx.args[i];
        if (tensor->m_grad_info.grad_fn) {
            auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
            // tensor is attached to a live GradKey
            if (input_grad_key && input_grad_key->active) {
                if (grad_key) {
                    if (grad_key != input_grad_key) {
                        PyErr_SetString(PyExc_NotImplementedError, "second order grad");
                        throw pyext17::py_err_set();
                    }
                } else {
                    grad_key = std::move(input_grad_key);
                }
            } else {
                // cleanup stale grad info
                // under what condition?
                tensor->m_grad_info = {};
                tensor->m_flags &= ~Tensor::Flags::GRAD;
            }
        } else {
            tensor->m_flags &= ~Tensor::Flags::GRAD;
        }
    }
    ctx.flags &= ~Tensor::Flags::GRAD;
    // perform forward apply_op or trace
    auto outputs = apply(ctx);
    if (!grad_key) {
        return outputs;
    }
    auto backward_graph = make_backward_graph(ctx, outputs);
    if (!backward_graph) {
        return outputs;
    }
    auto grad_fn = std::make_shared<GradFn>();
    grad_fn->key = grad_key;
    grad_fn->slots.resize(outputs.size());
    grad_fn->backward_graph = std::move(backward_graph);
    grad_fn->dsts.reserve(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        if (grad_fn->backward_graph->input_has_grad[i]) {
            auto& input_grad_info = ctx.args[i]->m_grad_info;
            grad_fn->dsts.emplace_back(input_grad_info);
            grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
        } else {
            grad_fn->dsts.emplace_back();
        }
    }
    auto& save_for_backward = grad_fn->backward_graph->save_for_backward;
    grad_fn->closure.reserve(std::count_if(save_for_backward.begin(), save_for_backward.end(), [](bool p){return p;}));
    // given op, taking gradient of output_tensor_list wrt input_tensor_list:
    //
    // save_for_backward[0:nargs-1]: whether input tensor requires gradient,
    // i.e., whether it is in input_tensor_list
    //
    // save_for_backward[nargs:nargs+outputs.size()-1]: whether output tensor is
    // needed to calculate gradients
    //
    // save_for_backward[-outputs.size():]: whether output tensor is in
    // output_tensor_list
    //
    // Example: perform c = a * b, where a is input data, b is parameter to be
    // optimized, save_for_backward = [1, 1, 0, 1]
    mgb_assert(ctx.nargs + 2 * outputs.size() == save_for_backward.size());
    // record input tensors needed to take grad
    for (size_t i = 0; i < ctx.nargs; ++i) {
        if (save_for_backward[i]) {
            grad_fn->closure.push_back(ctx.args[i]->shared_from_this());
        }
    }
    // record output tensors needed to take grad
    for (size_t i = 0; i < outputs.size(); ++i) {
        bool requires_grad = save_for_backward[ctx.nargs + outputs.size() + i];
        if (save_for_backward[ctx.nargs + i]) {
            grad_fn->closure.push_back(outputs[i]);
            if (requires_grad) {
                // avoid reference cycle [Tensor <-> GradFn]
                outputs[i] = outputs[i]->copy();
            }
        }
        if (requires_grad) {
            auto& grad_info = outputs[i]->m_grad_info;
            grad_info.grad_fn = grad_fn;
            grad_info.idx = i;
            grad_info.insert_after(grad_key->free_vars_head);
            outputs[i]->m_flags |= Tensor::Flags::GRAD;
        }
    }
    // record forward history
    grad_key->tape.emplace_back(grad_fn);
    return outputs;
}
void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* tw = TensorWrapper::cast_safe(args[0]);
    if (!tw) {
        throw py::type_error("argument 1 must be Tensor");
    }
    auto* tensor = tw->m_tensor.get();
    py::object callback;
    if (args[1] != Py_None) {
        callback = py::reinterpret_borrow<py::object>(args[1]);
    }
    m_key->attach(tensor, std::move(callback));
}
//! GradKey is weakly referred to by tensor->m_grad_info.grad_fn->key after attach
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
    if (!active) {
        throw py::value_error("grad key finalized");
    }
    if (tensor->m_grad_info.grad_fn) {
        if (tensor->m_grad_info.grad_fn->key.lock().get() != this) {
            PyErr_SetString(PyExc_NotImplementedError, "second order grad");
            throw pyext17::py_err_set();
        }
        if (tensor->m_grad_info->callback) {
            throw py::value_error("callback already set on this tensor");
        }
    } else {
        tensor->m_grad_info.idx = 0;
        auto& grad_fn = tensor->m_grad_info.grad_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
        tensor->m_grad_info.insert_after(free_vars_head);
        tensor->m_flags |= Tensor::Flags::GRAD;
    }
    tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
}
void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) {
    if (!grad) {
        grad = std::forward<decltype(delta)>(delta);
        return;
    }
    static ApplyContext ctx;
    if (!ctx.op) {
        ctx.op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
        ctx.nargs = 2;
    }
    Tensor* args[2] = {grad.get(), delta.get()};
    ctx.args = args;
    ctx.flags = grad->m_flags | delta->m_flags;
    grad = apply(ctx)[0];
}
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    if (!active) {
        throw py::value_error("finalized");
    }
    if (tensors.size() != grads.size()) {
        throw py::value_error("tensor and grad size mismatch");
    }
    // this GradKey is marked inactive here
    active = false;
    struct CleanupGuard {
        GradKey* owner;
        CleanupGuard(GradKey* this_) : owner(this_) {}
        ~CleanupGuard() {owner->cleanup();}
    } _cleanup_guard(this);
    if (tape.empty() || grads.empty()) return;
    PyTypeObject* pytype = Py_TYPE(grads[0]->self().ptr());
    for (size_t i = 0; i < tensors.size(); ++i) {
        auto& grad_info = tensors[i]->m_tensor->m_grad_info;
        if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
            grad_info->grad = grads[i]->m_tensor;
        }
    }
    std::vector<std::shared_ptr<GradFn>> ref_keeper;
    ref_keeper.reserve(tape.size());
    // back-propagation in reverse order
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto&& grad_fn = tape[k].lock();
        if (!grad_fn) continue;
        if (grad_fn->backward_graph) {
            for (size_t i = 0; i < grad_fn->slots.size(); ++i) {
                // grad_fn->dsts correspond to input tensors during forward
                // calculation, grad_fn->slots correspond to output tensors.
                // condition true means the output tensor has gradient for
                // back-propagation
                if (grad_fn->backward_graph->save_for_backward[grad_fn->dsts.size() + grad_fn->slots.size() + i]) {
                    grad_fn->closure.push_back(std::move(grad_fn->slots[i].grad));
                }
            }
            ApplyContext ctx;
            ctx.op = grad_fn->backward_graph->backward;
            ctx.flags = 0;
            ctx.nargs = grad_fn->closure.size();
            Tensor* args[ctx.nargs];
            for (size_t i = 0; i < ctx.nargs; ++i) {
                args[i] = grad_fn->closure[i].get();
                mgb_assert(args[i]);
                ctx.flags |= args[i]->m_flags;
            }
            ctx.args = args;
            auto grads = apply(ctx);
            size_t j = 0;
            for (size_t i = 0; i < grad_fn->dsts.size(); ++i) {
                if (grad_fn->backward_graph->input_has_grad[i]) {
                    auto& dst = grad_fn->dsts[i];
                    // grads[j] is consumed in accum_grad
                    accum_grad(dst->grad, std::move(grads[j]));
                    ++j;
                }
            }
            mgb_assert(j == grads.size());
        }
        for (auto&& dst : grad_fn->dsts) {
            if (!dst.grad_fn) continue;
            if (!dst.grad_fn->in_ref_keeper) {
                dst.grad_fn->in_ref_keeper = true;
                ref_keeper.push_back(dst.grad_fn);
            }
            // grad_fn->clear will unlink current dst.producer_record
            // such that if dst.producer_record.next is false, dst accumulates
            // all the gradients
            if (!dst.producer_record.next && dst->callback && dst->grad) {
                dst->callback(TensorWrapper::make(pytype, dst->grad));
            }
        }
        grad_fn->clear();
    } // finish tape loop
}
void GradKey::cleanup() {
    active = false;
    tape.clear();
    for (intrusive_list::Iterator it(free_vars_head); it;) {
        it->grad_fn.reset();
        (it++)->unlink();
    }
}
void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    m_key->backward(std::move(tensors), std::move(grads));
}
GradKey::~GradKey() {
    cleanup();
}
} // namespace mgb::imperative::python
@@ -0,0 +1,54 @@
/**
 * \file imperative/python/src/grad.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "./tensor.h"
#include <megbrain/utils/small_vector.h>
#include <memory>
namespace mgb::imperative::python {
apply_result_t apply_grad(ApplyContext& ctx);
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
    std::string name;
    bool active = true;
    GradInfo::head_t free_vars_head;
    std::vector<std::weak_ptr<GradFn>> tape;
    ~GradKey();
    void attach(Tensor* tensor, pybind11::object callback);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    void cleanup();
};
struct GradKeyWrapper {
    using wrap_t = pyext17::wrap<GradKeyWrapper>;
    static constexpr auto tp_name = pybind11::detail::_("GradKey");
    std::shared_ptr<GradKey> m_key;
    inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}
    void attach(PyObject*const* args, size_t nargs);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
};
} // namespace mgb::imperative::python
namespace pybind11::detail {
template<> struct type_caster<mgb::imperative::python::GradKeyWrapper> : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};
} // namespace pybind11::detail
@@ -0,0 +1,36 @@
/**
 * \file imperative/python/src/grad_info.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <memory>
#include "./intrusive_list.h"
namespace mgb::imperative::python {
struct GradFn;
struct GradSlot;
struct GradSlotPtr {
    std::shared_ptr<GradFn> grad_fn;
    size_t idx;
    GradSlot* operator->();
};
struct GradInfo : GradSlotPtr, intrusive_list::Node<GradInfo, intrusive_list::before_t> {
    GradInfo() = default;
    GradInfo(GradInfo&) = default;
    GradInfo(GradInfo&&) = default;
    GradInfo& operator=(GradInfo&) = default;
    GradInfo& operator=(GradInfo&&) = default;
};
} // namespace mgb::imperative::python
@@ -0,0 +1,227 @@
/**
 * \file imperative/python/src/intrusive_list.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megbrain/utils/metahelper.h"
namespace mgb::imperative::python::intrusive_list {
// copy policy
struct after_t {};
struct before_t {};
struct disable_t {};
template <typename T> struct Tail;
// invariant: next->prev == this
template <typename T>
struct Head {
    Tail<T>* next;
    Head(Tail<T>* node = nullptr) : next(node) {}
    Head(const Head<T>&) = delete;
    Head<T>& operator=(const Head<T>&) = delete;
    Head(Head<T>&& rhs) : next(rhs.next) {
        rhs.next = nullptr;
        if (next) {
            next->prev = this;
        }
    }
    Head<T>& operator=(Head<T>&& rhs) {
        mgb_assert(!next);
        next = rhs.next;
        rhs.next = nullptr;
        if (next) {
            next->prev = this;
        }
        return *this;
    }
    ~Head() {
        if (next) {
            next->prev = nullptr;
        }
    }
};
// invariant: prev->next == this
template <typename T>
struct Tail {
    Head<T>* prev;
    Tail(Head<T>* node = nullptr) : prev(node) {}
    Tail(const Tail<T>&) = delete;
    Tail<T>& operator=(const Tail<T>&) = delete;
    Tail(Tail<T>&& rhs) : prev(rhs.prev) {
        rhs.prev = nullptr;
        if (prev) {
            prev->next = this;
        }
    }
    Tail<T>& operator=(Tail<T>&& rhs) {
        mgb_assert(!prev);
        prev = rhs.prev;
        rhs.prev = nullptr;
        if (prev) {
            prev->next = this;
        }
        return *this;
    }
    ~Tail() {
        if (prev) {
            prev->next = nullptr;
        }
    }
};
template <typename T, typename policy> struct Node;
template <typename T>
class Iterator {
    T* ptr;
    void inc() {ptr = static_cast<T*>(ptr->Head<T>::next);}
    void dec() {ptr = static_cast<T*>(ptr->Head<T>::prev);}
public:
    Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {}
    Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {}
    template<typename policy>
    Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {}
    T& operator*() {return *static_cast<T*>(ptr);}
    T* operator->() {return static_cast<T*>(ptr);}
    operator bool() {return ptr;}
    bool operator==(const Iterator<T>& rhs) {return ptr == rhs.ptr;}
    Iterator& operator++() {inc(); return *this;}
    Iterator& operator--() {dec(); return *this;}
    Iterator operator++(int) {auto ret = *this; inc(); return ret;}
    Iterator operator--(int) {auto ret = *this; dec(); return ret;}
};
// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container.
// Instead, nodes may join or leave a list freely.
// NOTE: Derived classes have to explicitly declare copy / assignment as default,
// otherwise the compiler generated version would use the const T& signature,
// which is deleted.
template <typename T = void, typename policy = disable_t>
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>,
              Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> {
private:
    using this_t = Node<T, policy>;
    using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>;
public:
    using head_t = Head<U>;
    using tail_t = Tail<U>;
    using head_t::next;
    using tail_t::prev;
    Node() = default;
    Node(const this_t&) = delete;
    this_t& operator=(const this_t&) = delete;
    //! constructed node is inserted after the input node
    Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) {
        node.next = this;
        if (next) {
            next->prev = this;
        }
    }
    //! constructed node is inserted before the input node
    Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) {
        node.prev = this;
        if (prev) {
            prev->next = this;
        }
    }
    Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) {
        rhs.prev = nullptr;
        rhs.next = nullptr;
        if (prev) {
            prev->next = this;
        }
        if (next) {
            next->prev = this;
        }
    }
    Node& operator=(this_t&& rhs) {
        unlink();
        prev = rhs.prev;
        next = rhs.next;
        rhs.prev = nullptr;
        rhs.next = nullptr;
        if (prev) {
            prev->next = this;
        }
        if (next) {
            next->prev = this;
        }
        return *this;
    }
    template<typename p = policy,
             typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
    Node(this_t& rhs) : Node(policy{}, rhs) {}
    template<typename p = policy,
             typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
    this_t& operator=(this_t& rhs) {
        insert(policy{}, rhs);
        return *this;
    }
    void unlink() {
        if (prev) {
            prev->next = next;
        }
        if (next) {
            next->prev = prev;
        }
        prev = nullptr;
        next = nullptr;
    }
    //! this node is unlinked from its list and inserted after the input node
    void insert(after_t, head_t& node) {
        unlink();
        prev = &node;
        next = node.next;
        node.next = this;
        if (next) {
            next->prev = this;
        }
    }
    //! this node is unlinked from its list and inserted before the input node
    void insert(before_t, tail_t& node) {
        unlink();
        next = &node;
        prev = node.prev;
        node.prev = this;
        if (prev) {
            prev->next = this;
        }
    }
    void insert_before(tail_t& node) {insert(before_t{}, node);}
    void insert_after(head_t& node) {insert(after_t{}, node);}
    ~Node() {
        unlink();
    }
};
} // namespace mgb::imperative::python::intrusive_list
@@ -23,7 +23,10 @@
#include "./dispatcher.h"
#include "./tensor.h"
namespace py = pybind11;
using namespace mgb::imperative::python;
#ifndef MODULE_NAME
#define MODULE_NAME imperative_rt
@@ -68,4 +71,6 @@ PYBIND11_MODULE(MODULE_NAME, m) {
        py::getattr(m, "__dict__"));
    init_dispatcher(submodule(m, "dispatcher"));
    init_tensor(submodule(m, "core2"));
}
| @@ -15,6 +15,7 @@ | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <Python.h> | |||
| #include <pybind11/pybind11.h> | |||
| namespace pyext17 { | |||
| @@ -53,6 +54,26 @@ inline PyObject* cvt_retval(PyObject* rv) { | |||
| return cvt_retval(__VA_ARGS__); \ | |||
| } | |||
| inline int cvt_retint(int ret) { | |||
| return ret; | |||
| } | |||
| #define CVT_RET_INT(...) \ | |||
| if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \ | |||
| __VA_ARGS__; \ | |||
| return 0; \ | |||
| } else { \ | |||
| return cvt_retint(__VA_ARGS__); \ | |||
| } | |||
| struct py_err_set : std::exception {}; | |||
| #define HANDLE_ALL_EXC(RET) catch(py_err_set&) {return RET;} \ | |||
| catch(pybind11::error_already_set& e) {e.restore(); return RET;} \ | |||
| catch(pybind11::builtin_exception& e) {e.set_error(); return RET;} \ | |||
| catch(std::exception& e) {PyErr_SetString(PyExc_RuntimeError, e.what()); return RET;} | |||
| template <typename T> | |||
| struct wrap { | |||
| private: | |||
| @@ -111,7 +132,9 @@ private: | |||
| static PyObject* impl(PyObject* self, PyObject*) { | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| CVT_RET_PYOBJ((inst->*f)()); | |||
| try { | |||
| CVT_RET_PYOBJ((inst->*f)()); | |||
| } HANDLE_ALL_EXC(nullptr) | |||
| } | |||
| }; | |||
| @@ -121,7 +144,9 @@ private: | |||
| static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| CVT_RET_PYOBJ((inst->*f)(args, kwargs)); | |||
| try { | |||
| CVT_RET_PYOBJ((inst->*f)(args, kwargs)); | |||
| } HANDLE_ALL_EXC(nullptr) | |||
| } | |||
| }; | |||
| @@ -132,7 +157,9 @@ private: | |||
| static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) { | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| CVT_RET_PYOBJ((inst->*f)(args, nargs)); | |||
| try { | |||
| CVT_RET_PYOBJ((inst->*f)(args, nargs)); | |||
| } HANDLE_ALL_EXC(nullptr) | |||
| } | |||
| #else | |||
| static constexpr int flags = METH_VARARGS; | |||
| @@ -141,7 +168,9 @@ private: | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| auto* arr = &PyTuple_GET_ITEM(args, 0); | |||
| auto size = PyTuple_GET_SIZE(args); | |||
| CVT_RET_PYOBJ((inst->*f)(arr, size)); | |||
| try { | |||
| CVT_RET_PYOBJ((inst->*f)(arr, size)); | |||
| } HANDLE_ALL_EXC(nullptr) | |||
| } | |||
| #endif | |||
| }; | |||
| @@ -152,7 +181,9 @@ private: | |||
| static PyObject* impl(PyObject* self, PyObject* obj) { | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| CVT_RET_PYOBJ((inst->*f)(obj)); | |||
| try { | |||
| CVT_RET_PYOBJ((inst->*f)(obj)); | |||
| } HANDLE_ALL_EXC(nullptr) | |||
| } | |||
| }; | |||
| @@ -162,6 +193,55 @@ private: | |||
| return {name, (PyCFunction)M::impl, M::flags, doc}; | |||
| } | |||
| template<auto f> | |||
| struct getter { | |||
| using F = decltype(f); | |||
| static PyObject* impl(PyObject* self, void* closure) { | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| try { | |||
| if constexpr (std::is_invocable_v<F, PyObject*, void*>) { | |||
| CVT_RET_PYOBJ(f(self, closure)); | |||
| } else if constexpr (std::is_invocable_v<F, T, void*>) { | |||
| CVT_RET_PYOBJ((inst->*f)(closure)); | |||
| } else if constexpr (std::is_invocable_v<F, T>) { | |||
| CVT_RET_PYOBJ((inst->*f)()); | |||
| } else { | |||
| static_assert(!std::is_same_v<F, F>); | |||
| } | |||
| } HANDLE_ALL_EXC(nullptr) | |||
| } | |||
| }; | |||
| template<auto f> | |||
| struct setter { | |||
| using F = decltype(f); | |||
| template<typename = void> | |||
| static int impl_(PyObject* self, PyObject* val, void* closure) { | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| try { | |||
| if constexpr (std::is_invocable_v<F, PyObject*, PyObject*, void*>) { | |||
| CVT_RET_INT(f(self, val, closure)); | |||
| } else if constexpr (std::is_invocable_v<F, T, PyObject*, void*>) { | |||
| CVT_RET_INT((inst->*f)(val, closure)); | |||
| } else if constexpr (std::is_invocable_v<F, T, PyObject*>) { | |||
| CVT_RET_INT((inst->*f)(val)); | |||
| } else { | |||
| static_assert(!std::is_same_v<F, F>); | |||
| } | |||
| } HANDLE_ALL_EXC(-1) | |||
| } | |||
| static constexpr auto impl = []() {if constexpr (std::is_same_v<F, std::nullptr_t>) return nullptr; | |||
| else return impl_<>;}(); | |||
| }; | |||
| template<auto get, auto set = nullptr> | |||
| static constexpr PyGetSetDef make_getset_def(const char* name, const char* doc = nullptr, void* closure = nullptr) { | |||
| return {const_cast<char *>(name), getter<get>::impl, setter<set>::impl, const_cast<char *>(doc), closure}; | |||
| } | |||
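With make_getset_def, a getter (and optionally a setter) becomes a PyGetSetDef slot; leaving set at its nullptr default makes setter<nullptr>::impl evaluate to nullptr, so the attribute is read-only. A sketch of the Python-visible behavior, assuming the shape getset registered in init_tensor later in this patch:

```python
# Sketch of the effect of def_getset with the default set=nullptr: the member
# becomes an attribute backed by the C++ getter, and assignment fails because
# no setter slot was installed (a getset descriptor with a null setter raises).
import megengine as mge

x = mge.Tensor([1.0, 2.0])
print(x.shape)       # runs TensorWrapper::shape() through getter<...>::impl
try:
    x.shape = (2,)   # PyGetSetDef.set == nullptr -> read-only attribute
except AttributeError as e:
    print(e)
```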
| // polyfills | |||
| struct tp_vectorcall { | |||
| @@ -216,16 +296,26 @@ private: | |||
| template<typename = void> | |||
| static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { | |||
| struct FreeGuard { | |||
| PyObject* self; | |||
| PyTypeObject* type; | |||
| ~FreeGuard() {if (self) type->tp_free(self);} | |||
| }; | |||
| auto* self = type->tp_alloc(type, 0); | |||
| FreeGuard free_guard{self, type}; | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| if constexpr (has_vectorcall && tp_vectorcall::valid) { | |||
| reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | |||
| } | |||
| if constexpr (varkw) { | |||
| new(inst) T(args, kwargs); | |||
| } else { | |||
| new(inst) T(); | |||
| } | |||
| try { | |||
| if constexpr (varkw) { | |||
| new(inst) T(args, kwargs); | |||
| } else { | |||
| new(inst) T(); | |||
| } | |||
| } HANDLE_ALL_EXC(nullptr) | |||
| free_guard.self = nullptr; | |||
| return self; | |||
| } | |||
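FreeGuard makes tp_new exception-safe: if the T constructor throws, the guard's destructor hands the freshly allocated object back to tp_free; on success the guard is disarmed by nulling self. The leak itself is hard to observe from Python, so this sketch only shows that a failing constructor now raises cleanly instead of leaving half-built objects behind:

```python
# Sketch: a throwing constructor no longer leaks the allocation made by
# tp_alloc; FreeGuard's destructor releases it unless disarmed on success.
from megengine.core._imperative_rt.core2 import Tensor  # module path assumed

for _ in range(10_000):   # previously each failed construction leaked one object
    try:
        Tensor()          # "too few arguments" -> constructor throws
    except TypeError:
        pass
```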
| @@ -250,6 +340,7 @@ private: | |||
| public: | |||
| class TypeBuilder { | |||
| std::vector<PyMethodDef> m_methods; | |||
| std::vector<PyGetSetDef> m_getsets; | |||
| PyTypeObject m_type; | |||
| bool m_finalized = false; | |||
| bool m_ready = false; | |||
| @@ -259,6 +350,13 @@ public: | |||
| throw std::runtime_error("type is already finalized"); | |||
| } | |||
| } | |||
| static const char* to_c_str(const char* s) {return s;} | |||
| template <size_t N, typename... Ts> | |||
| static const char* to_c_str(const pybind11::detail::descr<N, Ts...>& desc) { | |||
| return desc.text; | |||
| } | |||
| public: | |||
| TypeBuilder(const TypeBuilder&) = delete; | |||
| TypeBuilder& operator=(const TypeBuilder&) = delete; | |||
| @@ -266,7 +364,7 @@ public: | |||
| TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} { | |||
| constexpr auto has_tp_name = HAS_MEMBER(T, tp_name); | |||
| if constexpr (has_tp_name) { | |||
| m_type.tp_name = T::tp_name; | |||
| m_type.tp_name = to_c_str(T::tp_name); | |||
| } | |||
| m_type.tp_dealloc = tp_dealloc::value; | |||
| #ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
| @@ -291,8 +389,17 @@ public: | |||
| return m_ready; | |||
| } | |||
| bool isinstance(PyObject* op) { | |||
| return PyObject_TypeCheck(op, &m_type); | |||
| } | |||
| bool isexact(PyObject* op) { | |||
| return Py_TYPE(op) == &m_type; | |||
| } | |||
| PyObject* finalize() { | |||
| if (!m_finalized) { | |||
| m_finalized = true; | |||
| if (m_methods.size()) { | |||
| m_methods.push_back({0}); | |||
| if (m_type.tp_methods) { | |||
| @@ -301,6 +408,14 @@ public: | |||
| } | |||
| m_type.tp_methods = &m_methods[0]; | |||
| } | |||
| if (m_getsets.size()) { | |||
| m_getsets.push_back({0}); | |||
| if (m_type.tp_getset) { | |||
| PyErr_SetString(PyExc_SystemError, "tp_getset is already set"); | |||
| return nullptr; | |||
| } | |||
| m_type.tp_getset = &m_getsets[0]; | |||
| } | |||
| if (PyType_Ready(&m_type)) { | |||
| return nullptr; | |||
| } | |||
| @@ -315,12 +430,64 @@ public: | |||
| m_methods.push_back(make_meth_def<f>(name, doc)); | |||
| return *this; | |||
| } | |||
| template<auto get, auto set = nullptr> | |||
| TypeBuilder& def_getset(const char* name, const char* doc = nullptr, void* closure = nullptr) { | |||
| check_finalized(); | |||
| m_getsets.push_back(make_getset_def<get, set>(name, doc, closure)); | |||
| return *this; | |||
| } | |||
| }; | |||
| static TypeBuilder& type() { | |||
| static TypeBuilder type_helper; | |||
| return type_helper; | |||
| } | |||
| template<typename... Args> | |||
| static PyObject* cnew(Args&&... args) { | |||
| auto* pytype = type().operator->(); | |||
| auto* self = pytype->tp_alloc(pytype, 0); | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| if constexpr (has_vectorcall && tp_vectorcall::valid) { | |||
| reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | |||
| } | |||
| new(inst) T(std::forward<Args>(args)...); | |||
| return self; | |||
| } | |||
| template<typename... Args> | |||
| static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) { | |||
| auto* self = pytype->tp_alloc(pytype, 0); | |||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
| if constexpr (has_vectorcall && tp_vectorcall::valid) { | |||
| reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | |||
| } | |||
| new(inst) T(std::forward<Args>(args)...); | |||
| return self; | |||
| } | |||
| struct caster { | |||
| static constexpr auto name = T::tp_name; | |||
| T* value; | |||
| bool load(pybind11::handle src, bool convert) { | |||
| if (wrap_t::type().isinstance(src.ptr())) { | |||
| value = reinterpret_cast<wrap_t*>(src.ptr())->inst(); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| template <typename U> using cast_op_type = pybind11::detail::cast_op_type<U>; | |||
| operator T*() { return value; } | |||
| operator T&() { return *value; } | |||
| }; | |||
| }; | |||
| } // namespace pyext17 | |||
| @@ -328,3 +495,5 @@ public: | |||
| #undef HAS_MEMBER_TYPE | |||
| #undef HAS_MEMBER | |||
| #undef CVT_RET_PYOBJ | |||
| #undef CVT_RET_INT | |||
| #undef HANDLE_ALL_EXC | |||
| @@ -0,0 +1,257 @@ | |||
| /** | |||
| * \file imperative/python/src/tensor.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./tensor.h" | |||
| #include "./grad.h" | |||
| #include "./common.h" | |||
| #include "./numpy_dtypes.h" | |||
| #include <pybind11/numpy.h> | |||
| #include <pybind11/operators.h> | |||
| namespace py = pybind11; | |||
| namespace mgb::imperative::python { | |||
| std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | |||
| apply_result_t apply(ApplyContext& ctx) { | |||
| // Emulating scalar should be put into each specific op's apply, e.g., | |||
| // elementwise, reduce, typecvt. Currently it is still handled on the Python | |||
| // side; it could be moved to the C++ side if it has an impact on performance. | |||
| if (ctx.flags & Tensor::Flags::SCALAR) { | |||
| // TODO: emulate scalar | |||
| } | |||
| if (ctx.flags & Tensor::Flags::GRAD) { | |||
| return apply_grad(ctx); | |||
| } | |||
| if (ctx.flags & Tensor::Flags::TRACE) { | |||
| // TODO: trace | |||
| } else { | |||
| SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| handles[i] = ctx.args[i]->m_handle.get(); | |||
| } | |||
| auto output_handles = interpreter_for_py->apply_op(ctx.op, handles); | |||
| apply_result_t outputs; | |||
| outputs.reserve(output_handles.size()); | |||
| for (auto h : output_handles) { | |||
| outputs.emplace_back(std::make_shared<Tensor>(h)); | |||
| } | |||
| return outputs; | |||
| } | |||
| mgb_assert(0); | |||
| } | |||
| PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) { | |||
| try { | |||
| // if (kwnames && PyTuple_GET_SIZE(kwnames)) { | |||
| // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed"); | |||
| // return nullptr; | |||
| // } | |||
| if (!nargs) { | |||
| PyErr_SetString(PyExc_TypeError, "expect Op"); | |||
| return nullptr; | |||
| } | |||
| auto* op = args[0]; | |||
| if (!strcmp(op->ob_type->tp_base->tp_name, "PodOpVisitor") || !strcmp(op->ob_type->tp_base->tp_name, "IndexingOpBase")) { | |||
| op = PyObject_CallMethod(op, "to_c", ""); | |||
| } | |||
| PyTypeObject* pytype = args[1]->ob_type; | |||
| ++args; | |||
| --nargs; | |||
| ApplyContext ctx; | |||
| ctx.flags = 0; | |||
| ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>(); | |||
| SmallVector<Tensor*, 64> tensors(nargs); | |||
| ctx.args = &tensors[0]; | |||
| ctx.nargs = nargs; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| TensorWrapper* tw = TensorWrapper::cast_safe(args[i]); | |||
| if (!tw) { | |||
| PyErr_SetString(PyExc_TypeError, "expect Tensor"); | |||
| return nullptr; | |||
| } | |||
| auto* t = tensors[i] = tw->m_tensor.get(); | |||
| ctx.flags |= t->m_flags; | |||
| } | |||
| // TODO: set TRACE flag | |||
| auto outputs = apply(ctx); | |||
| size_t nout = outputs.size(); | |||
| auto ret = py::tuple(nout); | |||
| for (size_t i = 0; i < nout; ++i) { | |||
| ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); | |||
| } | |||
| return ret.release().ptr(); | |||
| } catch (std::exception& e) { | |||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | |||
| return nullptr; | |||
| } | |||
| } | |||
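py_apply is exposed as core2.apply in init_tensor below: the first argument is an OpDef, the rest are tensors, and the outputs come back as a tuple wrapped in the type of the first tensor argument. A hedged usage sketch; the Elemwise constructor spelling is an assumption based on the tests later in this patch:

```python
# Hedged usage sketch of the METH_FASTCALL `apply` registered in init_tensor
# below: apply(op, *tensors) -> tuple of outputs, wrapped in type(tensors[0]).
import numpy as np
import megengine as mge
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops.builtin import Elemwise

x = mge.Tensor(np.random.rand(3).astype("float32"))
(y,) = apply(Elemwise(Elemwise.Mode.MUL), x, x)   # Elemwise ctor spelling assumed
np.testing.assert_allclose(y.numpy(), x.numpy() ** 2, rtol=1e-6)
```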
| TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
| if (kwargs && PyDict_Size(kwargs)) { | |||
| throw py::type_error("keyword argument not allowed"); | |||
| } | |||
| auto nargs = PyTuple_Size(args); | |||
| auto tup = py::reinterpret_borrow<py::tuple>(args); | |||
| if (nargs == 0) { | |||
| throw py::type_error("too few arguments"); | |||
| } | |||
| if (auto* t = cast_safe(tup[0].ptr())) { | |||
| if (nargs > 1) { | |||
| throw py::type_error("expect 1 argument"); | |||
| } | |||
| m_tensor = t->m_tensor; | |||
| } else { | |||
| if (nargs != 3) { | |||
| throw py::type_error("expect 3 arguments"); | |||
| } | |||
| py::detail::loader_life_support life_sup; // required to cast DType | |||
| auto data = tup[0].cast<py::array>(); | |||
| DType dtype = tup[1].cast<DType>(); | |||
| CompNode cn = tup[2].cast<CompNode>(); | |||
| interpreter::Interpreter::Handle handle; | |||
| constexpr auto size_threshold = TensorShape::MAX_NDIM; | |||
| if (data.size() > size_threshold) { | |||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); | |||
| } else { | |||
| HostTensorND ret(cn); | |||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); | |||
| } | |||
| m_tensor = std::make_shared<Tensor>(handle); | |||
| if (data.ndim() == 0) { | |||
| m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||
| } | |||
| } | |||
| } | |||
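The constructor therefore accepts exactly two argument shapes: a single existing tensor, whose handle is shared, or the (data, dtype, device) triple. A sketch; the dtype and device spellings below are assumptions for illustration, not verified API:

```python
# Sketch of the two constructor forms accepted above. The dtype/device
# spellings (numpy dtype object + device string) are assumptions.
import numpy as np
from megengine.core._imperative_rt.core2 import Tensor

a = Tensor(np.arange(4, dtype=np.float32), np.float32, "xpux")  # 3-argument form
b = Tensor(a)   # 1-argument form: b shares a's m_tensor and storage handle
```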
| PyObject* TensorWrapper::shape() { | |||
| if (m_tensor->m_flags & Tensor::Flags::SCALAR) { | |||
| return PyTuple_New(0); | |||
| } | |||
| auto&& shape = m_tensor->shape(); | |||
| if (!shape.ndim) { | |||
| Py_RETURN_NONE; | |||
| } | |||
| py::tuple ret(shape.ndim); | |||
| for (size_t i = 0; i < shape.ndim; ++i) { | |||
| ret[i] = shape[i]; | |||
| } | |||
| return ret.release().ptr(); | |||
| } | |||
| PyObject* TensorWrapper::dtype() { | |||
| return py::cast(m_tensor->dtype()).release().ptr(); | |||
| } | |||
| PyObject* TensorWrapper::device() { | |||
| return py::cast(m_tensor->comp_node()).release().ptr(); | |||
| } | |||
| PyObject* TensorWrapper::numpy() { | |||
| auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); | |||
| auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); | |||
| if (!arr) return nullptr; | |||
| if (m_tensor->m_flags & Tensor::Flags::SCALAR) { | |||
| mgb_assert(PyArray_Check(arr.ptr())); | |||
| return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr())); | |||
| } | |||
| return arr.release().ptr(); | |||
| } | |||
| void TensorWrapper::reset(PyObject* tensor) { | |||
| TensorWrapper* t = TensorWrapper::cast_safe(tensor); | |||
| if (!t) { | |||
| throw py::type_error("expect Tensor"); | |||
| } | |||
| m_tensor = t->m_tensor; | |||
| } | |||
| PyObject* TensorWrapper::isscalar() { | |||
| if(m_tensor->m_flags & Tensor::Flags::SCALAR) { | |||
| Py_RETURN_TRUE; | |||
| } else { | |||
| Py_RETURN_FALSE; | |||
| } | |||
| } | |||
| void TensorWrapper::setscalar() { | |||
| m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||
| } | |||
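isscalar/setscalar expose the SCALAR flag that shape() and numpy() above consult. A sketch, assuming the Python-level Tensor forwards 0-dim inputs unchanged to this constructor:

```python
# Sketch of the SCALAR flag round-trip: shape() short-circuits to () and
# numpy() squeezes the result once the flag is set.
import numpy as np
import megengine as mge

s = mge.Tensor(np.float32(3.0))   # 0-dim input sets Flags::SCALAR in the ctor
assert s.isscalar() and s.shape == ()
assert s.numpy().ndim == 0

t = mge.Tensor([3.0])
assert not t.isscalar()
t.setscalar()                     # force the flag; shape now reports ()
assert t.shape == ()
```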
| struct TensorWeakRef { | |||
| std::weak_ptr<Tensor> wptr; | |||
| TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {} | |||
| py::object operator()() { | |||
| if (auto p = wptr.lock()) { | |||
| return TensorWrapper::make(p); | |||
| } | |||
| return py::none(); | |||
| } | |||
| }; | |||
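TensorWeakRef holds a weak_ptr to the wrapped tensor: calling it returns the tensor while some owner is alive, and None afterwards. A sketch relying on CPython's immediate refcounting:

```python
# Sketch of TensorWeakRef semantics (registered in init_tensor below).
import megengine as mge
from megengine.core._imperative_rt.core2 import TensorWeakRef

x = mge.Tensor([1.0])
ref = TensorWeakRef(x)
assert ref() is not None   # tensor still alive, wptr.lock() succeeds
del x                      # drops the only shared_ptr owner
assert ref() is None       # lock fails, py::none() is returned
```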
| void init_tensor(py::module m) { | |||
| interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | |||
| auto* tensor_type = TensorWrapper::wrap_t::type() | |||
| .def<&TensorWrapper::numpy>("numpy") | |||
| .def_getset<&TensorWrapper::shape>("shape") | |||
| .def_getset<&TensorWrapper::dtype>("dtype") | |||
| .def_getset<&TensorWrapper::device>("device") | |||
| .def<&TensorWrapper::reset>("_reset") | |||
| .def<&TensorWrapper::isscalar>("isscalar") | |||
| .def<&TensorWrapper::setscalar>("setscalar") | |||
| .finalize(); | |||
| if (!tensor_type) throw py::error_already_set(); | |||
| py::setattr(m, "Tensor", tensor_type); | |||
| py::class_<TensorWeakRef>(m, "TensorWeakRef") | |||
| .def(py::init<const TensorWrapper&>()) | |||
| .def("__call__", &TensorWeakRef::operator()); | |||
| static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; | |||
| auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr); | |||
| if (!apply_func) throw py::error_already_set(); | |||
| py::setattr(m, "apply", apply_func); | |||
| py::handle grad_key_type = GradKeyWrapper::wrap_t::type() | |||
| .def<&GradKeyWrapper::attach>("attach") | |||
| .finalize(); | |||
| if (!grad_key_type) throw py::error_already_set(); | |||
| py::setattr(m, "GradKey", grad_key_type); | |||
| py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward)); | |||
| } | |||
| } // namespace mgb::imperative::python | |||
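An end-to-end sketch of the bindings registered in init_tensor, using the raw core2 names; the (key, ys, dys) signature of backward and the one-argument callback convention are inferred from their use elsewhere in this patch:

```python
# Hedged end-to-end sketch: attach a tensor to a GradKey, run some elemwise
# ops (which set Tensor::Flags::GRAD and route through apply_grad), then call
# backward with output tensors and their incoming gradients.
import numpy as np
import megengine as mge
from megengine.core._imperative_rt import core2

x = mge.Tensor(np.random.rand(10).astype("float32"))
grads = {}

def save(grad):          # new single-argument callback convention
    grads["x"] = grad

key = core2.GradKey()
key.attach(x, save)
y = x * x
core2.backward(key, [y], [mge.Tensor(np.ones(10, dtype="float32"))])
np.testing.assert_allclose(grads["x"].numpy(), 2 * x.numpy(), rtol=1e-6)
```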
| @@ -0,0 +1,157 @@ | |||
| /** | |||
| * \file imperative/python/src/tensor.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <variant> | |||
| #include "megbrain/imperative/interpreter.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "./pyext17.h" | |||
| namespace mgb::imperative::python { | |||
| template<typename T, typename B = pybind11::object> | |||
| struct ObjectPtr : B { | |||
| using B::B; | |||
| T& operator*() {return reinterpret_cast<T&>(*B::ptr());} | |||
| T* operator->() {return reinterpret_cast<T*>(B::ptr());} | |||
| }; | |||
| } // namespace mgb::imperative::python | |||
| #include "./grad_info.h" // for struct GradInfo | |||
| namespace mgb::imperative::python { | |||
| struct TraceInfo { | |||
| }; | |||
| extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | |||
| class SharedHandle { | |||
| using Handle = interpreter::Interpreter::Handle; | |||
| static_assert(std::is_pointer_v<Handle>); | |||
| std::shared_ptr<std::remove_pointer_t<Handle>> holder; | |||
| public: | |||
| inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){ | |||
| interpreter_for_py->del(h); | |||
| }) {} | |||
| SharedHandle(const SharedHandle&) = default; | |||
| SharedHandle& operator=(const SharedHandle&) = default; | |||
| SharedHandle(SharedHandle&&) = default; | |||
| SharedHandle& operator=(SharedHandle&&) = default; | |||
| inline Handle get() {return holder.get();} | |||
| }; | |||
| struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
| using flags_t = uint64_t; | |||
| struct Flags { | |||
| static constexpr flags_t SCALAR = 1; | |||
| static constexpr flags_t GRAD = 1 << 1; | |||
| static constexpr flags_t TRACE = 1 << 2; | |||
| }; | |||
| flags_t m_flags = 0; | |||
| GradInfo m_grad_info; | |||
| TraceInfo m_trace_info; | |||
| SharedHandle m_handle; | |||
| using Handle = interpreter::Interpreter::Handle; | |||
| inline explicit Tensor(Handle handle) : m_handle(handle) {} | |||
| inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)) {} | |||
| ~Tensor() = default; | |||
| inline std::shared_ptr<Tensor> copy() { | |||
| auto ret = std::make_shared<Tensor>(m_handle); | |||
| ret->m_flags = m_flags; | |||
| ret->m_grad_info = m_grad_info; | |||
| ret->m_trace_info = m_trace_info; | |||
| return ret; | |||
| } | |||
| inline DType dtype() {return interpreter_for_py->get_dtype(m_handle.get());} | |||
| inline CompNode comp_node() {return interpreter_for_py->get_device(m_handle.get());} | |||
| inline TensorShape shape() {return interpreter_for_py->get_shape(m_handle.get());} | |||
| }; | |||
| struct TensorWrapper { | |||
| std::shared_ptr<Tensor> m_tensor; | |||
| inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) : m_tensor(std::move(tensor)) {} | |||
| TensorWrapper(PyObject* args, PyObject* kwargs); | |||
| ~TensorWrapper() = default; | |||
| static constexpr auto tp_name = pybind11::detail::_("Tensor"); | |||
| using wrap_t = pyext17::wrap<TensorWrapper>; | |||
| friend wrap_t; | |||
| inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();} | |||
| inline static TensorWrapper* cast_safe(PyObject* op) { | |||
| if (!wrap_t::type().isinstance(op)) return nullptr; | |||
| return cast(op); | |||
| } | |||
| inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);} | |||
| template <typename... Args> | |||
| static ObjectPtr<Tensor> make(Args&&... args) { | |||
| auto* op = wrap_t::cnew(std::forward<Args>(args)...); | |||
| return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op); | |||
| } | |||
| template <typename... Args> | |||
| static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) { | |||
| auto* op = wrap_t::cnew_with_type(pytype,std::forward<Args>(args)...); | |||
| return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op); | |||
| } | |||
| PyObject* shape(); | |||
| PyObject* dtype(); | |||
| PyObject* device(); | |||
| PyObject* numpy(); | |||
| void reset(PyObject*); | |||
| PyObject* isscalar(); | |||
| void setscalar(); | |||
| }; | |||
| PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | |||
| struct ApplyContext { | |||
| Tensor::flags_t flags; | |||
| std::shared_ptr<OpDef> op; | |||
| Tensor*const* args; | |||
| size_t nargs; | |||
| }; | |||
| using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; | |||
| apply_result_t apply(ApplyContext& ctx); | |||
| void init_tensor(pybind11::module); | |||
| } // namespace mgb::imperative::python | |||
| namespace pybind11::detail { | |||
| template<> struct type_caster<mgb::imperative::python::TensorWrapper> : mgb::imperative::python::TensorWrapper::wrap_t::caster {}; | |||
| } // namespace pybind11::detail | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file imperative/python/src/trace.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| namespace mgb::imperative::python { | |||
| struct TraceInfo { | |||
| }; | |||
| } // namespace mgb::imperative::python | |||
| @@ -12,6 +12,7 @@ def test_basic(): | |||
| config_async_level(3) | |||
| @pytest.mark.skip | |||
| def test_level1_infer_value(): | |||
| config_async_level(1) | |||
| a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32") | |||
| @@ -22,6 +23,7 @@ def test_level1_infer_value(): | |||
| d = F.reshape(a, c) | |||
| @pytest.mark.skip | |||
| def test_level1_infer_shape_with_unknown(): | |||
| config_async_level(2) | |||
| a = mge.tensor([[1, 2, 2, 3]], dtype="float32") | |||
| @@ -16,12 +16,11 @@ import pytest | |||
| import megengine as mge | |||
| import megengine.distributed as dist | |||
| import megengine.functional as F | |||
| from megengine.core._imperative_rt import TensorAttr, imperative | |||
| from megengine.core._imperative_rt import TensorAttr, core2, imperative | |||
| from megengine.core._imperative_rt.core2 import TensorWeakRef, apply | |||
| from megengine.core._imperative_rt.imperative import sync | |||
| from megengine.core.autodiff.grad import Grad | |||
| from megengine.core.ops.builtin import Elemwise | |||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
| from megengine.core.tensor.tensor import Tensor, apply | |||
| from megengine.core.tensor.tensor_wrapper import TensorWrapper | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| from megengine.functional.distributed import remote_recv, remote_send | |||
| @@ -43,11 +42,11 @@ relu = _elwise(Elemwise.Mode.RELU) | |||
| def as_tensor(x): | |||
| return Tensor(as_raw_tensor(x, device=mge.device.get_default_device())) | |||
| return mge.Tensor(x) | |||
| def save_to(self, name="grad"): | |||
| def callback(tensor, grad): | |||
| def callback(grad): | |||
| setattr(self, name, grad) | |||
| return callback | |||
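Since the callback now receives only the gradient, ad-hoc callbacks shrink to one-parameter functions, e.g.:

```python
# e.g., an inline one-argument callback under the new convention:
x = mge.Tensor([1.0, 2.0])
g = Grad().wrt(x, callback=lambda grad: print("dx =", grad.numpy()))
```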
| @@ -136,14 +135,14 @@ def test_2nd_grad(): | |||
| def test_grad_with_tensor_wrapper(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = mul(x, x) | |||
| y = mul(y, y) | |||
| grad(y, TensorWrapper(np.ones_like(x_np))) | |||
| grad(y, mge.Tensor(np.ones_like(x_np))) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | |||
| @@ -162,8 +161,8 @@ def test_release(): | |||
| finally: | |||
| gc.enable() | |||
| x = TensorWrapper([0.0]) | |||
| dy = TensorWrapper(np.ones_like(x.numpy())) | |||
| x = mge.Tensor([0.0]) | |||
| dy = mge.Tensor(np.ones_like(x.numpy())) | |||
| @check | |||
| def _(): | |||
| @@ -173,25 +172,25 @@ def test_release(): | |||
| @check | |||
| def _(): | |||
| with Grad().wrt(x) as g: | |||
| with Grad().wrt(x): | |||
| pass | |||
| @check | |||
| def _(): | |||
| with Grad().wrt(x) as g: | |||
| with Grad().wrt(x): | |||
| y = x * x | |||
| def test_grad_inplace(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = mul(x, x) | |||
| y *= y | |||
| grad(y, TensorWrapper(np.ones_like(x_np))) | |||
| grad(y, mge.Tensor(np.ones_like(x_np))) | |||
| np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | |||
| @@ -199,16 +198,16 @@ def test_elemwise_add(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| y_np = np.random.rand(10, 10).astype("float32") | |||
| dz_np = np.random.rand(10, 10).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| y = TensorWrapper(y_np) | |||
| dz = TensorWrapper(dz_np) | |||
| x = mge.Tensor(x_np) | |||
| y = mge.Tensor(y_np) | |||
| dz = mge.Tensor(dz_np) | |||
| refs = {} | |||
| def f(x, y): | |||
| x = x * 2 | |||
| refs["x"] = weakref.ref(x.__wrapped__) | |||
| refs["y"] = weakref.ref(y.__wrapped__) | |||
| refs["x"] = TensorWeakRef(x) | |||
| refs["y"] = TensorWeakRef(y) | |||
| return x + y | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| @@ -226,14 +225,14 @@ def test_elemwise_add(): | |||
| def test_elemwise_relu(): | |||
| x_np = [1.0, -1.0] | |||
| dz_np = [1.0] | |||
| x = TensorWrapper(x_np) | |||
| dz = TensorWrapper(dz_np) | |||
| x = mge.Tensor(x_np) | |||
| dz = mge.Tensor(dz_np) | |||
| refs = {} | |||
| def f(x): | |||
| x = x * 2 | |||
| refs["x"] = weakref.ref(x.__wrapped__) | |||
| refs["x"] = TensorWeakRef(x) | |||
| return relu(x) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| @@ -258,7 +257,7 @@ def test_elemwise_relu_backward_fn(): | |||
| def test_reshape(): | |||
| x_np = np.random.rand(2, 5).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = x.reshape(5, 2) | |||
| @@ -269,7 +268,7 @@ def test_reshape(): | |||
| def test_subtensor(): | |||
| x_np = np.random.rand(3, 3).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = x[1:-1, :2] | |||
| @@ -282,7 +281,7 @@ def test_subtensor(): | |||
| def test_IndexingMultiAxisVec(): | |||
| x_np = np.random.rand(3, 3).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = x[[0, 2], [0, 2]] | |||
| @@ -295,7 +294,7 @@ def test_IndexingMultiAxisVec(): | |||
| def test_AxisAddRemove(): | |||
| x_np = np.random.rand(1, 5).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = F.squeeze(F.expand_dims(x, 2), 0) | |||
| @@ -308,7 +307,7 @@ def test_AxisAddRemove(): | |||
| def test_Broadcast(): | |||
| x_np = np.random.rand(3, 3, 1).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = F.broadcast_to(x, (3, 3, 10)) | |||
| @@ -319,7 +318,7 @@ def test_Broadcast(): | |||
| def test_Reduce_sum(): | |||
| x_np = np.random.rand(3, 3).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = x.sum(axis=0) | |||
| @@ -330,7 +329,7 @@ def test_Reduce_sum(): | |||
| def test_Reduce_mean(): | |||
| x_np = np.random.rand(3, 3).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = mge.Tensor(x_np) | |||
| grad = Grad().wrt(x, callback=save_to(x)) | |||
| y = x.mean(axis=0) | |||
| @@ -11,30 +11,29 @@ import collections | |||
| import numpy as np | |||
| import pytest | |||
| import megengine.core.tensor.raw_tensor | |||
| import megengine | |||
| from megengine.tensor import Tensor | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.ops import builtin | |||
| from megengine.core.tensor import Tensor | |||
| from megengine.core.tensor.core import apply | |||
| from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor | |||
| def cvt_to_shape_desc(val, inpvar, config=None): | |||
| def as_tensor(val, device): | |||
| assert device is not None, "can not infer device" | |||
| # TODO: should copy to appropriate device | |||
| val = as_raw_tensor(val, device=device) | |||
| val = Tensor(val, device=device) | |||
| return val | |||
| device = None | |||
| if inpvar is not None: | |||
| assert isinstance(inpvar, RawTensor) | |||
| assert isinstance(inpvar, Tensor) | |||
| device = device or inpvar.device | |||
| if config is not None: | |||
| device = device or config.device | |||
| if isinstance(val, RawTensor): | |||
| if isinstance(val, Tensor): | |||
| return as_tensor(val, device) | |||
| if not isinstance(val, collections.abc.Iterable): | |||
| @@ -43,7 +42,7 @@ def cvt_to_shape_desc(val, inpvar, config=None): | |||
| components = [] | |||
| on_host = True | |||
| for i in val: | |||
| if isinstance(i, RawTensor): | |||
| if isinstance(i, Tensor): | |||
| on_host = False | |||
| device = device or i.device | |||
| else: | |||
| @@ -62,7 +61,7 @@ def cvt_to_shape_desc(val, inpvar, config=None): | |||
| return as_tensor(shape, device) | |||
| for idx, v in enumerate(components): | |||
| if not isinstance(v, RawTensor): | |||
| if not isinstance(v, Tensor): | |||
| vi = int(v) | |||
| assert vi == v, "could not convert {} to int".format(v) | |||
| v = vi | |||
| @@ -95,7 +94,7 @@ def canonize_inputs(inputs, *, config): | |||
| # and is called with concat([a, b])) | |||
| inputs = inputs[0] | |||
| if isinstance(inputs, RawTensor): | |||
| if isinstance(inputs, Tensor): | |||
| return [inputs] | |||
| old_inputs = inputs | |||
| @@ -103,7 +102,7 @@ def canonize_inputs(inputs, *, config): | |||
| get_comp_node = None | |||
| need_cvt = False | |||
| for i in old_inputs: | |||
| if isinstance(i, RawTensor): | |||
| if isinstance(i, Tensor): | |||
| get_comp_node = lambda cn=i.device: cn | |||
| else: | |||
| need_cvt = True | |||
| @@ -117,8 +116,8 @@ def canonize_inputs(inputs, *, config): | |||
| return config.comp_node | |||
| for idx, var in enumerate(inputs): | |||
| if not isinstance(var, RawTensor): | |||
| var = as_raw_tensor(var) | |||
| if not isinstance(var, Tensor): | |||
| var = Tensor(var) | |||
| inputs[idx] = var | |||
| return inputs | |||
| @@ -131,15 +130,15 @@ def invoke_op(op, inputs_, cvt_inputs=canonize_inputs): | |||
| def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
| assert isinstance(inp, RawTensor) | |||
| assert isinstance(inp, Tensor) | |||
| if not isinstance(tuple_val, tuple): | |||
| tuple_val = (tuple_val,) | |||
| def as_tensor(v): | |||
| if not isinstance(v, RawTensor): | |||
| if not isinstance(v, Tensor): | |||
| vi = np.ascontiguousarray(v, dtype=np.int32) | |||
| assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v) | |||
| v = as_raw_tensor(vi) | |||
| v = Tensor(vi) | |||
| return v | |||
| new_axes = [] | |||
| @@ -275,14 +274,14 @@ def batched_incr_mesh_indexing(input, value, tuple_val): | |||
| def test_transpose(): | |||
| x = np.arange(10).reshape(2, 5).astype("int32") | |||
| xx = as_raw_tensor(x) | |||
| xx = Tensor(x) | |||
| (yy,) = transpose(xx, pattern=[1, -1, 0]) | |||
| np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy()) | |||
| def test_broadcast(): | |||
| x = np.arange(10).reshape(1, 10).astype("int32") | |||
| xx = as_raw_tensor(x) | |||
| xx = Tensor(x) | |||
| (yy,) = broadcast(xx, (10, 10)) | |||
| np.testing.assert_equal(np.repeat(x, 10, 0), yy.numpy()) | |||
| @@ -290,7 +289,7 @@ def test_broadcast(): | |||
| def test_subtensor(): | |||
| x = np.arange(25).reshape(5, 5).astype("int32") | |||
| d = np.arange(2).astype("int32") | |||
| xx = as_raw_tensor(x) | |||
| xx = Tensor(x) | |||
| (yy0,) = subtensor(xx, (slice(0, 4, 2), 3)) | |||
| (yy1,) = set_subtensor(xx, d, (slice(0, 4, 2), 3)) | |||
| (yy2,) = incr_subtensor(xx, d, (slice(0, 4, 2), 3)) | |||
| @@ -309,7 +308,7 @@ def test_subtensor(): | |||
| def test_advance_indexing(): | |||
| x = np.arange(25).reshape(5, 5).astype("int32") | |||
| d = np.arange(15).reshape(3, 5).astype("int32") | |||
| xx = as_raw_tensor(x) | |||
| xx = Tensor(x) | |||
| (yy0,) = advance_indexing(xx, ((0, 4, 2), slice(None, None, None))) | |||
| (yy1,) = set_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None))) | |||
| (yy2,) = incr_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None))) | |||
| @@ -328,7 +327,7 @@ def test_advance_indexing(): | |||
| def test_mesh_indexing(): | |||
| x = np.arange(25).reshape(5, 5).astype("int32") | |||
| d = np.arange(6).reshape(3, 2).astype("int32") | |||
| xx = as_raw_tensor(x) | |||
| xx = Tensor(x) | |||
| (yy0,) = mesh_indexing(xx, (slice(0, 5, 2), (1, 3))) | |||
| (yy1,) = set_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3))) | |||
| (yy2,) = incr_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3))) | |||
| @@ -355,7 +354,7 @@ def test_mesh_indexing(): | |||
| def test_batched_mesh_indexing(): | |||
| x = np.arange(24).reshape(2, 3, 4).astype("int32") | |||
| d = np.arange(12).reshape(2, 2, 3).astype("int32") | |||
| xx = as_raw_tensor(x) | |||
| xx = Tensor(x) | |||
| s = [(0, 1, 2), (1, 2, 3)] | |||
| (yy0,) = batched_mesh_indexing(xx, (slice(None, None, None), [(0, 2)] * 2, s)) | |||
| (yy1,) = batched_set_mesh_indexing( | |||
| @@ -9,12 +9,12 @@ | |||
| import numpy as np | |||
| from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | |||
| from megengine.core.tensor.tensor_wrapper import TensorWrapper | |||
| from megengine.tensor import Tensor | |||
| def test_basic(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = Tensor(x_np) | |||
| y = x * x | |||
| y_np = y.numpy() | |||
| np.testing.assert_almost_equal(y_np, x_np * x_np) | |||
| @@ -22,15 +22,15 @@ def test_basic(): | |||
| def test_literal_arith(): | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| x = Tensor(x_np) | |||
| y = x * 2 | |||
| y_np = y.numpy() | |||
| np.testing.assert_almost_equal(y_np, x_np * 2) | |||
| def test_matmul(): | |||
| A = TensorWrapper(np.random.rand(5, 7).astype("float32")) | |||
| B = TensorWrapper(np.random.rand(7, 10).astype("float32")) | |||
| A = Tensor(np.random.rand(5, 7).astype("float32")) | |||
| B = Tensor(np.random.rand(7, 10).astype("float32")) | |||
| C = A @ B | |||
| np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) | |||
| @@ -38,7 +38,7 @@ def test_matmul(): | |||
| def test_reduce(): | |||
| def test_x(x_np): | |||
| for m in ["sum", "prod", "min", "max", "mean"]: | |||
| x = TensorWrapper(x_np) | |||
| x = Tensor(x_np) | |||
| y = getattr(x, m)(axis=-1, keepdims=True) | |||
| np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6) | |||
| @@ -49,7 +49,7 @@ def test_reduce(): | |||
| def test_set_subtensor(): | |||
| x = TensorWrapper([1, 2, 3]) | |||
| x = Tensor([1, 2, 3]) | |||
| x[:] = [1, 1, 1] | |||
| np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) | |||
| x[[0, 2]] = [3, 2] | |||
| @@ -60,7 +60,7 @@ def test_set_subtensor(): | |||
| def test_computing_with_numpy_array(): | |||
| x = np.array([1, 2, 3], dtype=np.int32) | |||
| xx = TensorWrapper(x, device="cpu0") | |||
| xx = Tensor(x, device="cpu0") | |||
| y = np.array([1, 0, 3], dtype=np.int32) | |||
| assert np.add(xx, y).device == xx.device | |||
| np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y)) | |||
| @@ -70,12 +70,12 @@ def test_computing_with_numpy_array(): | |||
| def test_transpose(): | |||
| x = np.random.rand(2, 5).astype("float32") | |||
| xx = TensorWrapper(x) | |||
| xx = Tensor(x) | |||
| np.testing.assert_almost_equal(xx.T.numpy(), x.T) | |||
| def test_as_type(): | |||
| x = TensorWrapper([1, 2, 3], dtype=np.float32) | |||
| x = Tensor([1, 2, 3], dtype=np.float32) | |||
| y = x.astype(qint8(0.1)) | |||
| np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) | |||
| z = y.astype(qint8(0.2)) | |||
| @@ -312,7 +312,7 @@ def test_device(): | |||
| np.testing.assert_almost_equal(y1.numpy(), y2.numpy()) | |||
| y3 = F.eye(x.shape, dtype="float32", device="xpux") | |||
| y4 = F.eye(x.shape, dtype="float32", device=x.device.to_c()) | |||
| y4 = F.eye(x.shape, dtype="float32", device=x.device) | |||
| np.testing.assert_almost_equal(y3.numpy(), y4.numpy()) | |||
| y5 = F.full((3, 2), 4, device=x.device) | |||
| @@ -14,7 +14,7 @@ | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "./op_trait.h" | |||
| #include "./proxy_graph_detail.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -13,7 +13,7 @@ | |||
| #if MGB_ENABLE_OPR_MM | |||
| #include "../op_trait.h" | |||
| #include "../proxy_graph_detail.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| #include "megbrain/opr/mm_handler.h" | |||
| #include "megbrain/utils/hash.h" | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| @@ -13,7 +13,7 @@ | |||
| #if MGB_ENABLE_OPR_MM | |||
| #include "../op_trait.h" | |||
| #include "../proxy_graph_detail.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| #include "megbrain/opr/io_remote.h" | |||
| #include "megbrain/opr/mm_handler.h" | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| @@ -13,7 +13,7 @@ | |||
| #include "megbrain/serialization/opr_load_dump.h" | |||
| #include "../op_trait.h" | |||
| #include "../proxy_graph_detail.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -10,7 +10,7 @@ | |||
| */ | |||
| #include "./proxy_graph.h" | |||
| #include "./proxy_graph_detail.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * \file imperative/src/impl/proxy_graph_detail.h | |||
| * \file imperative/src/include/megbrain/imperative/proxy_graph_detail.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||