GitOrigin-RevId: 2dd4e460ac
tags/v1.2.0
| @@ -301,7 +301,7 @@ class GradManager: | |||||
| if tensor is None: | if tensor is None: | ||||
| return | return | ||||
| def callback(_, grad, callbacks=spec.callbacks): | |||||
| def callback(grad, callbacks=spec.callbacks): | |||||
| for cb in callbacks: | for cb in callbacks: | ||||
| grad = cb(tensor, grad) | grad = cb(tensor, grad) | ||||
| self._gradients[id(tensor)] = grad | self._gradients[id(tensor)] = grad | ||||
| @@ -16,6 +16,7 @@ import numpy as np | |||||
| import megengine as mge | import megengine as mge | ||||
| from .._imperative_rt import core2 | |||||
| from ..ops.builtin import Elemwise, OpDef, RemoteSend | from ..ops.builtin import Elemwise, OpDef, RemoteSend | ||||
| from ..ops.special import Const | from ..ops.special import Const | ||||
| from ..tensor.core import TensorBase, TensorWrapperBase, apply | from ..tensor.core import TensorBase, TensorWrapperBase, apply | ||||
| @@ -418,3 +419,28 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | |||||
| @apply.register() | @apply.register() | ||||
| def _(op: Const, *_: typing.Optional[Tracer]): | def _(op: Const, *_: typing.Optional[Tracer]): | ||||
| return None | return None | ||||
| class Grad: | |||||
| def __init__(self): | |||||
| self._impl = core2.GradKey() | |||||
| def wrt(self, *tensors, callback=None): | |||||
| for x in tensors: | |||||
| self._impl.attach(x, callback) | |||||
| return self | |||||
| def __call__(self, ys, dys): | |||||
| from collections.abc import Sequence | |||||
| if not isinstance(ys, Sequence): | |||||
| ys = [ys] | |||||
| if not isinstance(dys, Sequence): | |||||
| dys = [dys] | |||||
| core2.backward(self._impl, ys, dys) | |||||
| def __enter__(self): | |||||
| return self | |||||
| def __exit__(self, _1, _2, _3): | |||||
| del self._impl | |||||
| @@ -6,11 +6,18 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import numpy as np | |||||
| from .._imperative_rt.core2 import Tensor | |||||
| from ..tensor.core import OpBase, TensorBase, apply | from ..tensor.core import OpBase, TensorBase, apply | ||||
| class Const(OpBase): | |||||
| class Const: | |||||
| def __init__(self, value=None, *, dtype=None, device=None): | def __init__(self, value=None, *, dtype=None, device=None): | ||||
| self.value = value | |||||
| self.value = np.asarray(value, dtype=dtype) | |||||
| self.dtype = dtype | self.dtype = dtype | ||||
| self.device = device | self.device = device | ||||
| def __call__(self, *reference): | |||||
| Wrapper = type(reference[0]) | |||||
| return (Wrapper(self.value, self.dtype, self.device),) | |||||
| @@ -13,9 +13,17 @@ import sys | |||||
| import typing | import typing | ||||
| from abc import ABC | from abc import ABC | ||||
| from .._imperative_rt.core2 import apply as apply2 | |||||
| from .multipledispatch import Dispatcher | from .multipledispatch import Dispatcher | ||||
| def apply_op(op, *args): | |||||
| Wrapper = type(args[0]) | |||||
| args = [arg._tensor for arg in args] | |||||
| results = apply2(op, *args) | |||||
| return tuple(map(Wrapper, results)) | |||||
| class OpBase(ABC): | class OpBase(ABC): | ||||
| def __call__(self, *args): | def __call__(self, *args): | ||||
| return apply(self, *args) | return apply(self, *args) | ||||
| @@ -10,10 +10,10 @@ from typing import Iterable | |||||
| import numpy as np | import numpy as np | ||||
| from .._imperative_rt.core2 import Tensor, apply | |||||
| from .._trace_option import use_symbolic_shape | from .._trace_option import use_symbolic_shape | ||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from ..ops.special import Const | from ..ops.special import Const | ||||
| from .core import TensorBase, TensorWrapperBase, apply | |||||
| from .utils import astensor1d, isscalar, make_shape_tuple | from .utils import astensor1d, isscalar, make_shape_tuple | ||||
| @@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
| return True | return True | ||||
| def get_index(i): | def get_index(i): | ||||
| if not isinstance(i, (TensorBase, TensorWrapperBase)): | |||||
| if not isinstance(i, (Tensor)): | |||||
| if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | ||||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | ||||
| else: | else: | ||||
| (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | ||||
| return i | return i | ||||
| assert isinstance(i, (TensorBase, TensorWrapperBase)) | |||||
| assert isinstance(i, Tensor) | |||||
| if i.dtype != np.bool_: | if i.dtype != np.bool_: | ||||
| return i | return i | ||||
| _, ind = apply(builtin.CondTake(), i, i) | _, ind = apply(builtin.CondTake(), i, i) | ||||
| @@ -198,8 +198,8 @@ def try_condtake(tensor, index): | |||||
| return [] | return [] | ||||
| if isinstance(index, np.ndarray): | if isinstance(index, np.ndarray): | ||||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | ||||
| assert isinstance(index, (TensorBase, TensorWrapperBase)) | |||||
| if not isinstance(tensor, (TensorWrapperBase, TensorBase)): | |||||
| assert isinstance(index, Tensor) | |||||
| if not isinstance(tensor, Tensor): | |||||
| raise TypeError("input must be a tensor") | raise TypeError("input must be a tensor") | ||||
| if tensor.device != index.device: | if tensor.device != index.device: | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -227,7 +227,7 @@ def getitem(tensor, index): | |||||
| op = builtin.IndexingMultiAxisVec(items=items) | op = builtin.IndexingMultiAxisVec(items=items) | ||||
| (result,) = apply(op, tensor, *tensors) | (result,) = apply(op, tensor, *tensors) | ||||
| if ret_scalar: | if ret_scalar: | ||||
| result.__wrapped__._data._isscalar = True | |||||
| result.setscalar() | |||||
| return result | return result | ||||
| @@ -239,7 +239,7 @@ def setitem(tensor, index, value): | |||||
| if index.shape[0] == 0: | if index.shape[0] == 0: | ||||
| return tensor | return tensor | ||||
| tensor = tensor.reshape(-1) | tensor = tensor.reshape(-1) | ||||
| if not isinstance(value, (TensorBase, TensorWrapperBase)): | |||||
| if not isinstance(value, Tensor): | |||||
| op = Const(value, dtype=tensor.dtype, device=tensor.device) | op = Const(value, dtype=tensor.dtype, device=tensor.device) | ||||
| (value,) = op(tensor) | (value,) = op(tensor) | ||||
| tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | ||||
| @@ -250,6 +250,7 @@ def setitem(tensor, index, value): | |||||
| op = builtin.Subtensor(items=items) | op = builtin.Subtensor(items=items) | ||||
| else: | else: | ||||
| op = builtin.IndexingMultiAxisVec(items=items) | op = builtin.IndexingMultiAxisVec(items=items) | ||||
| (tmp_result,) = apply(op, tensor, *tensors) | (tmp_result,) = apply(op, tensor, *tensors) | ||||
| # XXX: broadcast can always be applied even if shapes are equal | # XXX: broadcast can always be applied even if shapes are equal | ||||
| @@ -8,19 +8,20 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import abc | import abc | ||||
| import collections | import collections | ||||
| from typing import Union | |||||
| import numpy as np | import numpy as np | ||||
| from .._imperative_rt.common import CompNode | |||||
| from .._imperative_rt.core2 import Tensor, apply | |||||
| from .._trace_option import use_symbolic_shape | from .._trace_option import use_symbolic_shape | ||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from ..ops.builtin import Elemwise, GetVarShape | from ..ops.builtin import Elemwise, GetVarShape | ||||
| from ..ops.special import Const | from ..ops.special import Const | ||||
| from . import utils | from . import utils | ||||
| from .core import OpBase, TensorBase, TensorWrapperBase, apply | |||||
| from .core import OpBase, TensorBase, TensorWrapperBase | |||||
| from .indexing import getitem as _getitem | from .indexing import getitem as _getitem | ||||
| from .indexing import setitem as _setitem | from .indexing import setitem as _setitem | ||||
| from .raw_tensor import RawTensor, as_raw_tensor | |||||
| from .tensor import Tensor | |||||
| from .utils import isscalar | from .utils import isscalar | ||||
| from .utils import make_shape_tuple as _make_shape_tuple | from .utils import make_shape_tuple as _make_shape_tuple | ||||
| from .utils import setscalar | from .utils import setscalar | ||||
| @@ -41,6 +42,7 @@ def _elwise(*args, mode): | |||||
| ) | ) | ||||
| args = utils.convert_inputs(*args) | args = utils.convert_inputs(*args) | ||||
| (result,) = apply(op, *args) | (result,) = apply(op, *args) | ||||
| _isscalar = True | _isscalar = True | ||||
| for i in args: | for i in args: | ||||
| if isscalar(i) == False: | if isscalar(i) == False: | ||||
| @@ -84,9 +86,7 @@ def _reshape(x, shape): | |||||
| if unspec_axis is not None: | if unspec_axis is not None: | ||||
| raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | ||||
| unspec_axis = i | unspec_axis = i | ||||
| shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) | shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) | ||||
| if unspec_axis is None: | if unspec_axis is None: | ||||
| op = builtin.Reshape() | op = builtin.Reshape() | ||||
| else: | else: | ||||
| @@ -181,7 +181,6 @@ def _reduce(mode): | |||||
| elif isinstance(axis, collections.abc.Iterable): | elif isinstance(axis, collections.abc.Iterable): | ||||
| axis = list(axis) | axis = list(axis) | ||||
| axis.sort(reverse=True) | axis.sort(reverse=True) | ||||
| for ai in axis: | for ai in axis: | ||||
| op = builtin.Reduce(mode=mode, axis=ai) | op = builtin.Reduce(mode=mode, axis=ai) | ||||
| (data,) = apply(op, data) | (data,) = apply(op, data) | ||||
| @@ -221,10 +220,7 @@ def _todo(*_): | |||||
| def _expand_args(args): | def _expand_args(args): | ||||
| if len(args) == 1: | if len(args) == 1: | ||||
| if isinstance( | |||||
| args[0], | |||||
| (collections.abc.Sequence, TensorBase, TensorWrapperBase, np.ndarray), | |||||
| ): | |||||
| if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),): | |||||
| args = args[0] | args = args[0] | ||||
| return args | return args | ||||
| @@ -240,9 +236,8 @@ class ArrayMethodMixin(abc.ABC): | |||||
| return self.numpy().astype(dtype) | return self.numpy().astype(dtype) | ||||
| def __array_wrap__(self, array): | def __array_wrap__(self, array): | ||||
| return TensorWrapper( | |||||
| as_raw_tensor(array, dtype=array.dtype, device=self.device) | |||||
| ) | |||||
| Wrapper = type(self) | |||||
| return Wrapper(array, dtype=array.dtype, device=self.device) | |||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| def _reset(self, other): | def _reset(self, other): | ||||
| @@ -253,7 +248,11 @@ class ArrayMethodMixin(abc.ABC): | |||||
| pass | pass | ||||
| @abc.abstractproperty | @abc.abstractproperty | ||||
| def shape(self) -> tuple: | |||||
| def shape(self) -> Union[tuple, Tensor]: | |||||
| pass | |||||
| @abc.abstractproperty | |||||
| def _tuple_shape(self) -> tuple: | |||||
| pass | pass | ||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| @@ -331,7 +330,7 @@ class ArrayMethodMixin(abc.ABC): | |||||
| __complex__ = lambda self: complex(self.item()) | __complex__ = lambda self: complex(self.item()) | ||||
| def __len__(self): | def __len__(self): | ||||
| shape = self.__wrapped__.shape | |||||
| shape = self._tuple_shape | |||||
| if shape: | if shape: | ||||
| return int(shape[0]) | return int(shape[0]) | ||||
| raise TypeError("ndim is 0") | raise TypeError("ndim is 0") | ||||
| @@ -352,7 +351,7 @@ class ArrayMethodMixin(abc.ABC): | |||||
| @property | @property | ||||
| def ndim(self): | def ndim(self): | ||||
| shape = self.__wrapped__.shape | |||||
| shape = self._tuple_shape | |||||
| if shape is None: | if shape is None: | ||||
| raise ValueError("unkown ndim") | raise ValueError("unkown ndim") | ||||
| return len(shape) | return len(shape) | ||||
| @@ -480,22 +479,52 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): | |||||
| self.__wrapped__._swap_out() | self.__wrapped__._swap_out() | ||||
| class TensorWrapper(GenericTensorWrapper): | |||||
| def __init__(self, data, dtype=None, device=None): | |||||
| if isinstance(data, TensorWrapperBase): | |||||
| data = data.__wrapped__ | |||||
| elif not isinstance(data, TensorBase): | |||||
| assert data is not None, "Cannot init a tensor with data as None" | |||||
| data = Tensor(as_raw_tensor(data, dtype=dtype, device=device)) | |||||
| super().__init__(data) | |||||
| class TensorWrapper(ArrayMethodMixin, TensorBase): | |||||
| def __init__(self, data, dtype=None, device=None, isscalar=False): | |||||
| self._isscalar = isscalar | |||||
| if isinstance(data, Tensor): | |||||
| self._tensor = data | |||||
| else: | |||||
| if device is None: | |||||
| device = CompNode._get_default_device() | |||||
| self._tensor = Tensor(data, dtype, device) | |||||
| def _reset(self, other): | def _reset(self, other): | ||||
| if isinstance(other, TensorWrapperBase): | |||||
| self.__wrapped__ = other.__wrapped__ | |||||
| elif isinstance(other, TensorBase): | |||||
| self.__wrapped__ = other | |||||
| else: | |||||
| self._reset(type(self)(other, dtype=self.dtype, device=self.device)) | |||||
| if not isinstance(other, __class__): | |||||
| raise TypeError(type(other)) | |||||
| self._tensor = other._tensor | |||||
| return self | |||||
| @property | |||||
| def dtype(self): | |||||
| return self._tensor.dtype | |||||
| @property | |||||
| def shape(self): | |||||
| if self._isscalar: | |||||
| return () | |||||
| shape = self._tensor.shape | |||||
| if shape == () or not use_symbolic_shape(): | |||||
| return shape | |||||
| return apply(GetVarShape(), self)[0] | |||||
| @property | |||||
| def device(self): | |||||
| return self._tensor.device | |||||
| def numpy(self): | |||||
| if self._isscalar: | |||||
| return self._tensor.numpy().squeeze() | |||||
| return self._tensor.numpy() | |||||
| def _drop(self): | |||||
| self._tensor._drop() | |||||
| def _swap_in(self): | |||||
| self._tensor._swap_in() | |||||
| def _swap_out(self): | |||||
| self._tensor._swap_out() | |||||
| def __repr__(self): | def __repr__(self): | ||||
| piece = "Tensor(" | piece = "Tensor(" | ||||
| @@ -11,9 +11,10 @@ from typing import Iterable, Union | |||||
| import numpy as np | import numpy as np | ||||
| from .._imperative_rt.core2 import Tensor, apply | |||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from ..ops.special import Const | from ..ops.special import Const | ||||
| from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||||
| from ..tensor.core import OpBase, TensorBase, TensorWrapperBase | |||||
| from .dtype import is_equal, is_quantize | from .dtype import is_equal, is_quantize | ||||
| _enable_convert_inputs = True | _enable_convert_inputs = True | ||||
| @@ -109,7 +110,7 @@ def dtype_promotion(inputs): | |||||
| def get_device(inputs): | def get_device(inputs): | ||||
| device = None | device = None | ||||
| for i in inputs: | for i in inputs: | ||||
| if isinstance(i, (TensorWrapperBase, TensorBase)): | |||||
| if isinstance(i, Tensor): | |||||
| if device is None: | if device is None: | ||||
| device = i.device | device = i.device | ||||
| elif device != i.device: | elif device != i.device: | ||||
| @@ -126,30 +127,31 @@ def concatenate(inputs, axis=0, *, device=None): | |||||
| return convert_single_value(x, inputs, dtype=dtype) | return convert_single_value(x, inputs, dtype=dtype) | ||||
| inputs = tuple(map(convert, inputs)) | inputs = tuple(map(convert, inputs)) | ||||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inputs) | |||||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | |||||
| return result | return result | ||||
| def astype(x, dtype): | def astype(x, dtype): | ||||
| dtype = np.dtype(dtype) | dtype = np.dtype(dtype) | ||||
| if not is_equal(x.dtype, dtype): | if not is_equal(x.dtype, dtype): | ||||
| isscalar = x.__wrapped__._data._isscalar | |||||
| isscalar = x.isscalar() | |||||
| (x,) = apply(builtin.TypeCvt(dtype=dtype), x) | (x,) = apply(builtin.TypeCvt(dtype=dtype), x) | ||||
| x.__wrapped__._data._isscalar = isscalar | |||||
| if isscalar: | |||||
| x.setscalar() | |||||
| return x | return x | ||||
| def convert_single_value(v, inputs, *, dtype=None, device=None): | def convert_single_value(v, inputs, *, dtype=None, device=None): | ||||
| tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))] | |||||
| tensors = [i for i in inputs if isinstance(i, Tensor)] | |||||
| assert len(tensors) > 0 | assert len(tensors) > 0 | ||||
| if isinstance(v, (TensorWrapperBase, TensorBase)): | |||||
| if isinstance(v, (TensorWrapperBase, Tensor)): | |||||
| v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) | v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) | ||||
| else: | else: | ||||
| (v,) = Const(v, dtype=dtype, device=device)(*tensors) | (v,) = Const(v, dtype=dtype, device=device)(*tensors) | ||||
| return v | return v | ||||
| def convert_inputs(*args: TensorBase): | |||||
| def convert_inputs(*args: Tensor): | |||||
| if not _enable_convert_inputs: | if not _enable_convert_inputs: | ||||
| return args | return args | ||||
| @@ -167,7 +169,7 @@ def convert_inputs(*args: TensorBase): | |||||
| def result_type(*args): | def result_type(*args): | ||||
| dtypes = [] | dtypes = [] | ||||
| for i in args: | for i in args: | ||||
| if isinstance(i, (TensorWrapperBase, TensorBase)): | |||||
| if isinstance(i, Tensor): | |||||
| dtypes.append(i.dtype) | dtypes.append(i.dtype) | ||||
| continue | continue | ||||
| try: | try: | ||||
| @@ -178,25 +180,16 @@ def result_type(*args): | |||||
| def isscalar(x): | def isscalar(x): | ||||
| if isinstance(x, TensorWrapperBase): | |||||
| x = x.__wrapped__ | |||||
| if hasattr(x, "_isscalar"): | |||||
| return x._isscalar | |||||
| if isinstance(x, TensorBase): | |||||
| return x._data._isscalar | |||||
| if isinstance(x, Tensor): | |||||
| return x.isscalar() | |||||
| return np.isscalar(x) | return np.isscalar(x) | ||||
| def setscalar(x): | def setscalar(x): | ||||
| if isinstance(x, TensorWrapperBase): | |||||
| x = x.__wrapped__ | |||||
| if hasattr(x, "_isscalar"): | |||||
| x._isscalar = True | |||||
| elif isinstance(x, TensorBase): | |||||
| x._data._isscalar = True | |||||
| if isinstance(x, Tensor): | |||||
| x.setscalar() | |||||
| else: | else: | ||||
| raise NotImplementedError("Unsupport type {}".format(type(x))) | raise NotImplementedError("Unsupport type {}".format(type(x))) | ||||
| @@ -215,25 +208,24 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
| else: | else: | ||||
| if ndim != 0 and ndim != 1: | if ndim != 0 and ndim != 1: | ||||
| raise ValueError("ndim != 1 or 0, get : %d" % ndim) | raise ValueError("ndim != 1 or 0, get : %d" % ndim) | ||||
| if not isinstance(x, (TensorBase, TensorWrapperBase)): | |||||
| if not isinstance(x, Tensor): | |||||
| (x,) = Const(x, dtype=dtype, device=device)(*reference) | (x,) = Const(x, dtype=dtype, device=device)(*reference) | ||||
| return x | return x | ||||
| if not isinstance(x, collections.abc.Sequence): | if not isinstance(x, collections.abc.Sequence): | ||||
| raise TypeError | raise TypeError | ||||
| if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x): | |||||
| if any(isinstance(i, Tensor) for i in x): | |||||
| x = concatenate(x, device=device) | x = concatenate(x, device=device) | ||||
| if dtype is not None: | if dtype is not None: | ||||
| x = astype(x, dtype) | x = astype(x, dtype) | ||||
| return x | return x | ||||
| (x,) = Const(x, dtype=dtype, device=device)(*reference) | (x,) = Const(x, dtype=dtype, device=device)(*reference) | ||||
| return x | return x | ||||
| def _expand_int(s, i): | def _expand_int(s, i): | ||||
| if isinstance(i, (TensorBase, TensorWrapperBase)): | |||||
| if isinstance(i, Tensor): | |||||
| i_np = i.numpy() | i_np = i.numpy() | ||||
| if i_np.ndim == 0: | if i_np.ndim == 0: | ||||
| s.append(int(i_np)) | s.append(int(i_np)) | ||||
| @@ -8,6 +8,7 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from typing import Optional, Tuple | from typing import Optional, Tuple | ||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | ||||
| from ..core.autodiff.grad import ( | from ..core.autodiff.grad import ( | ||||
| Tracer, | Tracer, | ||||
| @@ -17,7 +18,6 @@ from ..core.autodiff.grad import ( | |||||
| tracer_apply, | tracer_apply, | ||||
| ) | ) | ||||
| from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | ||||
| from ..core.tensor.core import apply | |||||
| from ..core.tensor.tensor import Tensor, tensor_apply | from ..core.tensor.tensor import Tensor, tensor_apply | ||||
| from ..device import get_default_device | from ..device import get_default_device | ||||
| from ..tensor import tensor | from ..tensor import tensor | ||||
| @@ -39,71 +39,6 @@ __all__ = [ | |||||
| ] | ] | ||||
| @apply.register() | |||||
| def _(op: RemoteSend, *args: Tensor): | |||||
| ret = tensor_apply(op, *args) | |||||
| # set extra information | |||||
| tracer_set = dict() | |||||
| for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))): | |||||
| tracer_set[k.name] = True | |||||
| # check tracer_set in remote_recv | |||||
| get_client().set_remote_tracer(op.key, tracer_set) | |||||
| return ret | |||||
| @builtin_op_get_backward_fn.register(RemoteSend) | |||||
| def _(op: RemoteSend, inputs, outputs, input_requires_grad): | |||||
| def backward(*args): | |||||
| return [ | |||||
| remote_recv( | |||||
| op.rank_to, | |||||
| inputs[0].shape, | |||||
| inputs[0].dtype, | |||||
| device=str(inputs[0].device), | |||||
| inp=inputs[0], | |||||
| ) | |||||
| ] | |||||
| return backward, [True] | |||||
| @get_op_has_grad_fn.register(RemoteSend) | |||||
| def _(op: RemoteSend): | |||||
| def has_grad(opnode, reached): | |||||
| return get_client().check_is_grad(op.key) | |||||
| return has_grad | |||||
| @check_backward_allow_noinput.register(RemoteSend) | |||||
| def _(op: RemoteSend): | |||||
| return True | |||||
| @builtin_op_get_backward_fn.register(RemoteRecv) | |||||
| def _(op: RemoteRecv, inputs, outputs, input_requires_grad): | |||||
| def backward(*output_grads): | |||||
| return [remote_send(output_grads[0], op.rank_from)] | |||||
| return backward, [True] | |||||
| @get_op_has_grad_fn.register(RemoteRecv) | |||||
| def _(op: RemoteRecv): | |||||
| def has_grad(opnode, reached): | |||||
| ret = False | |||||
| for v in opnode.outputs: | |||||
| if v() in reached: | |||||
| ret = True | |||||
| break | |||||
| get_client().set_is_grad(op.key, ret) | |||||
| return ret | |||||
| return has_grad | |||||
| def collective_comm(inp, mode, group, device): | def collective_comm(inp, mode, group, device): | ||||
| """Helper function for applying collective communication functions.""" | """Helper function for applying collective communication functions.""" | ||||
| assert isinstance(group, Group) | assert isinstance(group, Group) | ||||
| @@ -17,8 +17,8 @@ import numpy as np | |||||
| from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager | from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager | ||||
| from megengine.device import get_default_device, get_device_count | from megengine.device import get_default_device, get_device_count | ||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core.ops.builtin import ParamPackConcat, ParamPackSplit | from ..core.ops.builtin import ParamPackConcat, ParamPackSplit | ||||
| from ..core.tensor.core import apply | |||||
| from ..functional.utils import copy | from ..functional.utils import copy | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from ..utils.future import Future | from ..utils.future import Future | ||||
| @@ -228,7 +228,6 @@ class AllreduceCallback: | |||||
| self._packing_size[dtype] = 0 | self._packing_size[dtype] = 0 | ||||
| def __call__(self, param, grad): | def __call__(self, param, grad): | ||||
| param = param.__wrapped__ | |||||
| gm = get_backwarding_grad_manager() | gm = get_backwarding_grad_manager() | ||||
| assert isinstance(gm, GradManager) | assert isinstance(gm, GradManager) | ||||
| if gm not in self._marked_gm: | if gm not in self._marked_gm: | ||||
| @@ -9,10 +9,10 @@ | |||||
| # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | ||||
| import functools | import functools | ||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.builtin import Elemwise | from ..core.ops.builtin import Elemwise | ||||
| from ..core.tensor import megbrain_graph, utils | from ..core.tensor import megbrain_graph, utils | ||||
| from ..core.tensor.core import apply | |||||
| from ..core.tensor.utils import isscalar, setscalar | from ..core.tensor.utils import isscalar, setscalar | ||||
| from ..device import get_default_device | from ..device import get_default_device | ||||
| from ..jit.tracing import is_tracing | from ..jit.tracing import is_tracing | ||||
| @@ -12,10 +12,11 @@ import math | |||||
| import numbers | import numbers | ||||
| from typing import Optional, Sequence, Tuple, Union | from typing import Optional, Sequence, Tuple, Union | ||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
| from ..core.tensor import utils | from ..core.tensor import utils | ||||
| from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||||
| from ..core.tensor.core import TensorBase, TensorWrapperBase | |||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from .elemwise import clip, exp, log, log1p | from .elemwise import clip, exp, log, log1p | ||||
| from .tensor import reshape, squeeze | from .tensor import reshape, squeeze | ||||
| @@ -10,12 +10,12 @@ | |||||
| from typing import Optional, Sequence, Tuple, Union | from typing import Optional, Sequence, Tuple, Union | ||||
| from ..core._imperative_rt import CompNode | from ..core._imperative_rt import CompNode | ||||
| from ..core._imperative_rt.core2 import Tensor, apply | |||||
| from ..core._trace_option import use_symbolic_shape | from ..core._trace_option import use_symbolic_shape | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.builtin import BatchNorm | from ..core.ops.builtin import BatchNorm | ||||
| from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
| from ..core.tensor import megbrain_graph, utils | from ..core.tensor import megbrain_graph, utils | ||||
| from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||||
| from ..core.tensor.utils import astensor1d | from ..core.tensor.utils import astensor1d | ||||
| from ..distributed import WORLD, is_distributed | from ..distributed import WORLD, is_distributed | ||||
| from ..jit.tracing import is_tracing | from ..jit.tracing import is_tracing | ||||
| @@ -1565,9 +1565,7 @@ def indexing_one_hot( | |||||
| [1.] | [1.] | ||||
| """ | """ | ||||
| assert isinstance( | |||||
| src, (TensorWrapperBase, TensorBase) | |||||
| ), "src must be of Tensor type" | |||||
| assert isinstance(src, Tensor), "src must be of Tensor type" | |||||
| op = builtin.IndexingOneHot(axis=axis) | op = builtin.IndexingOneHot(axis=axis) | ||||
| index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) | index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) | ||||
| (result,) = apply(op, src, index) | (result,) = apply(op, src, index) | ||||
| @@ -8,8 +8,8 @@ | |||||
| # pylint: disable=too-many-lines | # pylint: disable=too-many-lines | ||||
| from typing import Tuple, Union | from typing import Tuple, Union | ||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.tensor.core import apply | |||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from .debug_param import get_conv_execution_strategy | from .debug_param import get_conv_execution_strategy | ||||
| from .types import _pair, _pair_nonzero | from .types import _pair, _pair_nonzero | ||||
| @@ -14,10 +14,10 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union | |||||
| import numpy as np | import numpy as np | ||||
| from ..core._imperative_rt import CompNode | from ..core._imperative_rt import CompNode | ||||
| from ..core._imperative_rt.core2 import Tensor, apply | |||||
| from ..core._wrap import device as as_device | from ..core._wrap import device as as_device | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
| from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||||
| from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis | from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis | ||||
| from ..core.tensor.utils import ( | from ..core.tensor.utils import ( | ||||
| astensor1d, | astensor1d, | ||||
| @@ -611,11 +611,11 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: | |||||
| """ | """ | ||||
| x, y = convert_inputs(x, y) | x, y = convert_inputs(x, y) | ||||
| if not isinstance(x, (TensorWrapperBase, TensorBase)): | |||||
| if not isinstance(x, Tensor): | |||||
| raise TypeError("input x must be a tensor") | raise TypeError("input x must be a tensor") | ||||
| if not isinstance(y, (TensorWrapperBase, TensorBase)): | |||||
| if not isinstance(y, Tensor): | |||||
| raise TypeError("input y must be a tensor") | raise TypeError("input y must be a tensor") | ||||
| if not isinstance(mask, (TensorWrapperBase, TensorBase)): | |||||
| if not isinstance(mask, Tensor): | |||||
| raise TypeError("mask must be a tensor") | raise TypeError("mask must be a tensor") | ||||
| if mask.dtype != np.bool_: | if mask.dtype != np.bool_: | ||||
| raise ValueError("mask must be bool") | raise ValueError("mask must be bool") | ||||
| @@ -668,9 +668,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor: | |||||
| [1. 4.] [0 3] | [1. 4.] [0 3] | ||||
| """ | """ | ||||
| if not isinstance(x, (TensorWrapperBase, TensorBase)): | |||||
| if not isinstance(x, Tensor): | |||||
| raise TypeError("input must be a tensor") | raise TypeError("input must be a tensor") | ||||
| if not isinstance(mask, (TensorWrapperBase, TensorBase)): | |||||
| if not isinstance(mask, Tensor): | |||||
| raise TypeError("mask must be a tensor") | raise TypeError("mask must be a tensor") | ||||
| if mask.dtype != np.bool_: | if mask.dtype != np.bool_: | ||||
| raise ValueError("mask must be bool") | raise ValueError("mask must be bool") | ||||
| @@ -11,10 +11,10 @@ from typing import Iterable, Union | |||||
| import numpy as np | import numpy as np | ||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core._wrap import device as as_device | from ..core._wrap import device as as_device | ||||
| from ..core.ops.builtin import Copy, Identity | from ..core.ops.builtin import Copy, Identity | ||||
| from ..core.tensor import Tensor | |||||
| from ..core.tensor.core import apply | |||||
| from ..tensor import Tensor | |||||
| from .math import topk as _topk | from .math import topk as _topk | ||||
| from .tensor import broadcast_to, transpose | from .tensor import broadcast_to, transpose | ||||
| @@ -10,9 +10,9 @@ from typing import Iterable, Optional | |||||
| from .. import Tensor | from .. import Tensor | ||||
| from ..core._imperative_rt import invoke_op | from ..core._imperative_rt import invoke_op | ||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core.ops.builtin import GaussianRNG, UniformRNG | from ..core.ops.builtin import GaussianRNG, UniformRNG | ||||
| from ..core.tensor import utils | from ..core.tensor import utils | ||||
| from ..core.tensor.core import apply | |||||
| from .rng import _random_seed_generator | from .rng import _random_seed_generator | ||||
| __all__ = ["normal", "uniform"] | __all__ = ["normal", "uniform"] | ||||
| @@ -10,26 +10,66 @@ | |||||
| import collections | import collections | ||||
| from .core import Tensor as _Tensor | |||||
| from .core.ops.builtin import Copy | |||||
| from .core.tensor.core import apply | |||||
| import numpy as np | |||||
| from .core._imperative_rt import CompNode | |||||
| from .core._imperative_rt.core2 import Tensor as _Tensor | |||||
| from .core._imperative_rt.core2 import apply | |||||
| from .core._trace_option import use_symbolic_shape | |||||
| from .core.ops.builtin import Copy, GetVarShape | |||||
| from .core.tensor.raw_tensor import as_device | from .core.tensor.raw_tensor import as_device | ||||
| from .core.tensor.tensor_wrapper import ArrayMethodMixin | |||||
| from .device import _valid_device, get_default_device | from .device import _valid_device, get_default_device | ||||
| from .utils.deprecation import deprecated | from .utils.deprecation import deprecated | ||||
| class Tensor(_Tensor): | |||||
| class Tensor(_Tensor, ArrayMethodMixin): | |||||
| grad = None | grad = None | ||||
| dmap_callback = None | dmap_callback = None | ||||
| q_dict = {"mode": None, "scale": None, "zero_point": None} | |||||
| def __init__(self, data, dtype=None, device=None): | |||||
| def __new__(cls, data, dtype=None, device=None): | |||||
| if device is None: | if device is None: | ||||
| device = get_default_device() | |||||
| self.q_dict = {"mode": None, "scale": None, "zero_point": None} | |||||
| super().__init__(data, dtype=dtype, device=device) | |||||
| cn = get_default_device() | |||||
| elif isinstance(device, str): | |||||
| if cls.dmap_callback is not None: | |||||
| cn = CompNode(cls.dmap_callback(device)) | |||||
| else: | |||||
| cn = CompNode(device) | |||||
| else: | |||||
| assert isinstance(device, CompNode) | |||||
| cn = device | |||||
| if isinstance(data, _Tensor): | |||||
| obj = _Tensor.__new__(cls, data) | |||||
| else: | |||||
| obj = _Tensor.__new__(cls, data, dtype, cn) | |||||
| return obj | |||||
| @property | |||||
| def shape(self): | |||||
| shape = super().shape | |||||
| if shape == () or not use_symbolic_shape(): | |||||
| return shape | |||||
| return apply(GetVarShape(), self)[0] | |||||
| @property | |||||
| def _tuple_shape(self): | |||||
| return super().shape | |||||
| def __repr__(self): | |||||
| piece = "Tensor(" | |||||
| with np.printoptions(precision=4, suppress=True): | |||||
| piece += "{}".format(str(self.numpy())) | |||||
| if self.dtype != np.float32: | |||||
| piece += ", dtype={}".format(np.dtype(self.dtype).name) | |||||
| piece += ", device={}".format(self.device) + ")" | |||||
| return piece | |||||
| @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") | @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") | ||||
| def set_value(self, value): | def set_value(self, value): | ||||
| if not isinstance(value, _Tensor): | |||||
| value = Tensor(value, dtype=self.dtype, device=self.device) | |||||
| self._reset(value) | self._reset(value) | ||||
| @deprecated(version="1.0", reason="use *= 0 instead") | @deprecated(version="1.0", reason="use *= 0 instead") | ||||
| @@ -61,27 +101,22 @@ class Tensor(_Tensor): | |||||
| def __hash__(self): | def __hash__(self): | ||||
| return id(self) | return id(self) | ||||
| def __getnewargs__(self): | |||||
| r""" __getnewargs__ will be called for pickle serialization or deep copy | |||||
| """ | |||||
| return (self.numpy(), self.dtype, self.device.logical_name) | |||||
| def __getstate__(self): | def __getstate__(self): | ||||
| r""" __getstate__ will be called for pickle serialization or deep copy | r""" __getstate__ will be called for pickle serialization or deep copy | ||||
| """ | """ | ||||
| state = { | state = { | ||||
| "data": self.numpy(), | |||||
| "device": self.device.logical_name, | |||||
| "dtype": self.dtype, | |||||
| "qdict": self.q_dict, | "qdict": self.q_dict, | ||||
| } | } | ||||
| return state | return state | ||||
| def __setstate__(self, state): | def __setstate__(self, state): | ||||
| data = state.pop("data") | |||||
| logical_device = state.pop("device") | |||||
| if self.dmap_callback is not None: | |||||
| assert isinstance(logical_device, str) | |||||
| logical_device = self.dmap_callback(logical_device) | |||||
| dtype = state.pop("dtype") | |||||
| self.q_dict = state.pop("qdict") | self.q_dict = state.pop("qdict") | ||||
| super().__init__(data, dtype=dtype, device=logical_device) | |||||
| def detach(self): | def detach(self): | ||||
| r""" | r""" | ||||
| @@ -89,8 +124,7 @@ class Tensor(_Tensor): | |||||
| during backward gradient calcuation, i.e. its gradient is zero. | during backward gradient calcuation, i.e. its gradient is zero. | ||||
| """ | """ | ||||
| Wrapper = type(self) | Wrapper = type(self) | ||||
| Tensor = type(self.__wrapped__) | |||||
| return Wrapper(Tensor(self.__wrapped__._data)) | |||||
| return Wrapper(self) | |||||
| tensor = Tensor | tensor = Tensor | ||||
| @@ -0,0 +1,404 @@ | |||||
| /** | |||||
| * \file imperative/python/src/grad.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "./grad.h" | |||||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/utils/mempool.h" | |||||
| namespace py = pybind11; | |||||
| namespace mgb::imperative::python { | |||||
| namespace { | |||||
// Same two fields as GradSlotPtr, but the GradFn is held through a weak_ptr,
// so this reference does not keep the GradFn alive.
// NOTE(review): idx presumably indexes GradFn::slots like GradSlotPtr::idx — confirm.
| struct GradSlotWeakPtr { | |||||
| std::weak_ptr<GradFn> grad_fn; | |||||
| size_t idx; | |||||
| }; | |||||
| } // namespace | |||||
// Payload-free intrusive-list node. Records are linked after a
// GradSlot::producer_head; GradKey::backward checks producer_record.next to
// decide whether a slot's callback may be invoked.
| struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> { | |||||
| using Base = intrusive_list::Node<GradProducerRecord>; | |||||
| GradProducerRecord() = default; | |||||
// Constructs the record already linked directly after `head`.
| GradProducerRecord(GradProducerRecord::head_t& head) : Base(intrusive_list::after_t{}, head) {} | |||||
| // GradProducerRecord(GradProducerRecord&&) = default; | |||||
| // GradProducerRecord& operator=(GradProducerRecord&) = default; | |||||
| // GradProducerRecord& operator=(GradProducerRecord&&) = default; | |||||
| }; | |||||
// Per-tensor gradient state stored inside a GradFn (GradFn::slots).
| struct GradSlot { | |||||
// Gradient for this slot; seeded and accumulated (accum_grad) in GradKey::backward.
| std::shared_ptr<Tensor> grad; | |||||
// Optional Python callable set by GradKey::attach and called from GradKey::backward.
| py::object callback; | |||||
// Head of the list of GradProducerRecord nodes linked by GradSlotProducerPtr.
| GradProducerRecord::head_t producer_head; | |||||
| }; | |||||
// A GradSlotPtr that additionally owns a GradProducerRecord linked into the
// referenced slot's producer list.
| struct GradSlotProducerPtr : GradSlotPtr { | |||||
| GradProducerRecord producer_record; | |||||
// Default: refers to no slot and the record is not linked anywhere.
| GradSlotProducerPtr() = default; | |||||
// Copies the GradSlotPtr part of `info` and links producer_record after
// that slot's producer_head.
| GradSlotProducerPtr(GradInfo& info) : GradSlotPtr(info), producer_record(info->producer_head) {} | |||||
| }; | |||||
| struct GradFn : std::enable_shared_from_this<GradFn> { | |||||
| static MemPool<GradFn> pool; | |||||
| std::weak_ptr<GradKey> key; | |||||
| SmallVector<GradSlot> slots; | |||||
| SmallVector<GradSlotProducerPtr> dsts; | |||||
| SmallVector<std::shared_ptr<Tensor>> closure; | |||||
| std::shared_ptr<BackwardGraphResult> backward_graph; | |||||
| bool in_ref_keeper = false; | |||||
| static void deleter(GradFn* ptr) { | |||||
| pool.free(ptr); | |||||
| } | |||||
| std::shared_ptr<GradFn> make() { | |||||
| return std::shared_ptr<GradFn>(pool.alloc(), &deleter); | |||||
| } | |||||
| void clear() { | |||||
| key.reset(); | |||||
| slots.clear(); | |||||
| dsts.clear(); | |||||
| closure.clear(); | |||||
| backward_graph.reset(); | |||||
| } | |||||
| }; | |||||
// Returns the address of slot `idx` inside the referenced GradFn.
// grad_fn is dereferenced unconditionally, so it must be non-null.
| GradSlot* GradSlotPtr::operator->() { | |||||
| return &grad_fn->slots[idx]; | |||||
| } | |||||
| namespace { | |||||
// File-local cache (instance: backward_graph_cache) mapping the hash key
// computed in make_backward_graph to its BackwardGraphResult. Stored values
// may be null: make_backward_graph also caches "no backward graph" results.
| struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject { | |||||
// Empties the cache when comp nodes are finalized; returns an empty pointer.
| std::shared_ptr<void> on_comp_node_finalize() override { | |||||
| clear(); | |||||
| return {}; | |||||
| } | |||||
| } backward_graph_cache; | |||||
// Returns the backward graph for applying ctx.op to ctx.args, or a null
// pointer when the generated result has no `backward` op. Results (including
// null) are memoized in backward_graph_cache.
//
// The cache key is a hash over: the op hash, then per input the dtype handle
// hash and comp node hash, followed by one bool per input telling whether the
// input currently has a grad_fn.
// NOTE(review): entries are looked up by the digest alone (no equality check
// on the hashed fields), so two different inputs with colliding digests would
// share one cached graph — confirm this is acceptable.
| std::shared_ptr<BackwardGraphResult> make_backward_graph( | |||||
| ApplyContext& ctx, const apply_result_t& outputs) { | |||||
| // hash | |||||
| static_assert(alignof(size_t) % alignof(bool) == 0); | |||||
// Layout of buf: (1 + 2 * nargs) size_t values, then nargs bool values.
| size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool); | |||||
// buf_size is a runtime value, so this is a variable-length array
// (a compiler extension in C++).
| alignas(alignof(size_t)) std::byte buf[buf_size]; | |||||
| size_t* size_t_ptr = reinterpret_cast<size_t*>(buf); | |||||
| bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2)); | |||||
| bool* bool_ptr0 = bool_ptr; | |||||
| *(size_t_ptr++) = ctx.op->hash(); | |||||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
| *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle()); | |||||
| *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node()); | |||||
| *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn); | |||||
| } | |||||
// Sanity check: both write cursors ended exactly at the end of their region.
| mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) && | |||||
| bool_ptr == reinterpret_cast<bool*>(buf + buf_size)); | |||||
| size_t key = XXHash{}.update(buf, buf_size).digest(); | |||||
| auto&& iter = backward_graph_cache.find(key); | |||||
| if (iter != backward_graph_cache.end()) { | |||||
| return iter->second; | |||||
| } | |||||
| // slow path | |||||
// Only comp_node and dtype of each input are filled in; every output is
// declared as having a gradient.
| SmallVector<LogicalTensorDesc> inputs(ctx.nargs); | |||||
| SmallVector<bool> input_requires_grad(ctx.nargs, false); | |||||
| SmallVector<bool> output_has_grad(outputs.size(), true); | |||||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
| inputs[i].comp_node = ctx.args[i]->comp_node(); | |||||
| inputs[i].layout.dtype = ctx.args[i]->dtype(); | |||||
| input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn); | |||||
| } | |||||
| auto result = std::make_shared<BackwardGraphResult>( | |||||
| proxy_graph_detail::make_backward_graph( | |||||
| *ctx.op, inputs, input_requires_grad, output_has_grad)); | |||||
// No backward op: cache and return a null pointer instead.
| if (!result->backward) { | |||||
| result.reset(); | |||||
| } | |||||
| backward_graph_cache.emplace(key, result); | |||||
| return result; | |||||
| } | |||||
| } // namespace | |||||
// Applies ctx.op through apply() and, when at least one input is attached to
// an active GradKey, records what is needed to back-propagate through it.
//
// Steps:
//  1. Scan the inputs for a GradKey that is still alive and active. Inputs
//     whose key is gone or inactive get their grad info and GRAD flag cleared.
//     Two inputs with different active keys raise NotImplementedError
//     ("second order grad").
//  2. Clear the GRAD flag on ctx and run the forward op.
//  3. If a key was found and a backward graph exists, build a GradFn, link it
//     to the inputs that receive gradients, save the tensors needed for
//     backward, tag the outputs and append the GradFn to the key's tape.
//
// Returns the forward outputs. Outputs that are both saved for backward and
// carry gradient are returned as copies (see the reference-cycle note below).
| apply_result_t apply_grad(ApplyContext& ctx) { | |||||
| std::shared_ptr<GradKey> grad_key; | |||||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
| auto* tensor = ctx.args[i]; | |||||
| if (tensor->m_grad_info.grad_fn) { | |||||
| auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock(); | |||||
| // tensor is attached to a live GradKey | |||||
| if (input_grad_key && input_grad_key->active) { | |||||
| if (grad_key) { | |||||
| if (grad_key != input_grad_key) { | |||||
| PyErr_SetString(PyExc_NotImplementedError, "second order grad"); | |||||
| throw pyext17::py_err_set(); | |||||
| } | |||||
| } else { | |||||
| grad_key = std::move(input_grad_key); | |||||
| } | |||||
| } else { | |||||
| // cleanup stale grad info | |||||
| // under what condition? | |||||
| tensor->m_grad_info = {}; | |||||
| tensor->m_flags &= ~Tensor::Flags::GRAD; | |||||
| } | |||||
| } else { | |||||
| tensor->m_flags &= ~Tensor::Flags::GRAD; | |||||
| } | |||||
| } | |||||
| ctx.flags &= ~Tensor::Flags::GRAD; | |||||
| // perform forward apply_op or trace | |||||
| auto outputs = apply(ctx); | |||||
| if (!grad_key) { | |||||
| return outputs; | |||||
| } | |||||
| auto backward_graph = make_backward_graph(ctx, outputs); | |||||
| if (!backward_graph) { | |||||
| return outputs; | |||||
| } | |||||
| auto grad_fn = std::make_shared<GradFn>(); | |||||
| grad_fn->key = grad_key; | |||||
| grad_fn->slots.resize(outputs.size()); | |||||
| grad_fn->backward_graph = std::move(backward_graph); | |||||
// dsts gets exactly one entry per input; reserve() keeps the entries (and the
// list nodes inside them) from being relocated while they are appended.
| grad_fn->dsts.reserve(ctx.nargs); | |||||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
| if (grad_fn->backward_graph->input_has_grad[i]) { | |||||
| auto& input_grad_info = ctx.args[i]->m_grad_info; | |||||
// The GradSlotProducerPtr constructor already links the record after
// producer_head; insert_after unlinks it and links it there again.
| grad_fn->dsts.emplace_back(input_grad_info); | |||||
| grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head); | |||||
| } else { | |||||
| grad_fn->dsts.emplace_back(); | |||||
| } | |||||
| } | |||||
| auto& save_for_backward = grad_fn->backward_graph->save_for_backward; | |||||
| grad_fn->closure.reserve(std::count_if(save_for_backward.begin(), save_for_backward.end(), [](bool p){return p;})); | |||||
// NOTE(review): in the code below save_for_backward[i] (i < nargs) decides
// whether input i is stored in the closure, while input_has_grad[i] (above)
// decides whether input i receives a gradient — confirm the description of
// the first range in the comment that follows.
| // given op, taking gradient of output_tensor_list wrt input_tensor_list: | |||||
| // | |||||
| // save_for_backward[0:nargs-1]: whether input tensor requires gradient, | |||||
| // i.e., whether it is in input_tensor_list | |||||
| // | |||||
| // save_for_backward[nargs:nargs+outputs.size()-1]: whether output tensor is | |||||
| // needed to calculate gradients | |||||
| // | |||||
| // save_for_backward[-outputs.size():]: whether output tensor is in | |||||
| // output_tensor_list | |||||
| // | |||||
| // Example: perform c = a * b, where a is input data, b is parameter to be | |||||
| // optimized, save_for_backward = [1, 1, 0, 1] | |||||
| mgb_assert(ctx.nargs + 2 * outputs.size() == save_for_backward.size()); | |||||
| // record input tensors needed to take grad | |||||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
| if (save_for_backward[i]) { | |||||
| grad_fn->closure.push_back(ctx.args[i]->shared_from_this()); | |||||
| } | |||||
| } | |||||
| // record output tensors needed to take grad | |||||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||||
| bool requires_grad = save_for_backward[ctx.nargs + outputs.size() + i]; | |||||
| if (save_for_backward[ctx.nargs + i]) { | |||||
| grad_fn->closure.push_back(outputs[i]); | |||||
| if (requires_grad) { | |||||
| // avoid reference cycle [Tensor <-> GradFn] | |||||
| outputs[i] = outputs[i]->copy(); | |||||
| } | |||||
| } | |||||
| if (requires_grad) { | |||||
| auto& grad_info = outputs[i]->m_grad_info; | |||||
| grad_info.grad_fn = grad_fn; | |||||
| grad_info.idx = i; | |||||
| grad_info.insert_after(grad_key->free_vars_head); | |||||
| outputs[i]->m_flags |= Tensor::Flags::GRAD; | |||||
| } | |||||
| } | |||||
| // record forward history | |||||
| grad_key->tape.emplace_back(grad_fn); | |||||
| return outputs; | |||||
| } | |||||
| void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) { | |||||
| if (nargs != 2) { | |||||
| throw py::type_error("expect 2 arguments"); | |||||
| } | |||||
| auto* tw = TensorWrapper::cast_safe(args[0]); | |||||
| if (!tw) { | |||||
| throw py::type_error("argument 1 must be Tensor"); | |||||
| } | |||||
| auto* tensor = tw->m_tensor.get(); | |||||
| py::object callback; | |||||
| if (args[1] != Py_None) { | |||||
| callback = py::reinterpret_borrow<py::object>(args[1]); | |||||
| } | |||||
| m_key->attach(tensor, std::move(callback)); | |||||
| } | |||||
| //! GradKey is weakly refered by tensor->m_grad_info.grad_fn->key after attach | |||||
| void GradKey::attach(Tensor* tensor, pybind11::object callback) { | |||||
| if (!active) { | |||||
| throw py::value_error("grad key finalized"); | |||||
| } | |||||
| if (tensor->m_grad_info.grad_fn) { | |||||
| if (tensor->m_grad_info.grad_fn->key.lock().get() != this) { | |||||
| PyErr_SetString(PyExc_NotImplementedError, "second order grad"); | |||||
| throw pyext17::py_err_set(); | |||||
| } | |||||
| if (tensor->m_grad_info->callback) { | |||||
| throw py::value_error("callback already set on this tensor"); | |||||
| } | |||||
| } else { | |||||
| tensor->m_grad_info.idx = 0; | |||||
| auto& grad_fn = tensor->m_grad_info.grad_fn; | |||||
| grad_fn = std::make_shared<GradFn>(); | |||||
| grad_fn->key = shared_from_this(); | |||||
| grad_fn->slots.resize(1); | |||||
| tensor->m_grad_info.insert_after(free_vars_head); | |||||
| tensor->m_flags |= Tensor::Flags::GRAD; | |||||
| } | |||||
| tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback); | |||||
| } | |||||
| void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) { | |||||
| if (!grad) { | |||||
| grad = std::forward<decltype(delta)>(delta); | |||||
| return; | |||||
| } | |||||
| static ApplyContext ctx; | |||||
| if (!ctx.op) { | |||||
| ctx.op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD)); | |||||
| ctx.nargs = 2; | |||||
| } | |||||
| Tensor* args[2] = {grad.get(), delta.get()}; | |||||
| ctx.args = args; | |||||
| ctx.flags = grad->m_flags | delta->m_flags; | |||||
| grad = apply(ctx)[0]; | |||||
| } | |||||
// Back-propagates from `tensors`, seeded with `grads`, through the recorded tape.
//
// Throws py::value_error when the key is no longer active or when the two
// vectors differ in length. The key is marked inactive right away and
// cleanup() always runs on exit (CleanupGuard), so a GradKey can be used for
// one backward pass only.
//
// For each GradFn on the tape, newest first: output gradients are appended to
// the closure, the backward op is applied, and the resulting gradients are
// accumulated into the input slots. A slot's callback is invoked once its
// gradient is present and no later producer record remains linked.
| void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) { | |||||
| if (!active) { | |||||
| throw py::value_error("finalized"); | |||||
| } | |||||
| if (tensors.size() != grads.size()) { | |||||
| throw py::value_error("tensor and grad size mismatch"); | |||||
| } | |||||
| // this GradKey is marked inactive here | |||||
| active = false; | |||||
| struct CleanupGuard { | |||||
| GradKey* owner; | |||||
| CleanupGuard(GradKey* this_) : owner(this_) {} | |||||
| ~CleanupGuard() {owner->cleanup();} | |||||
| } _cleanup_guard(this); | |||||
| if (tape.empty() || grads.empty()) return; | |||||
// Python type of the first grad; used to wrap gradients handed to callbacks.
| PyTypeObject* pytype = Py_TYPE(grads[0]->self().ptr()); | |||||
// Seed output gradients. Tensors not attached to this key are ignored.
// NOTE(review): the seed is assigned, not accumulated, so a tensor listed
// twice keeps only its last grad — confirm this is intended.
| for (size_t i = 0; i < tensors.size(); ++i) { | |||||
| auto& grad_info = tensors[i]->m_tensor->m_grad_info; | |||||
| if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) { | |||||
| grad_info->grad = grads[i]->m_tensor; | |||||
| } | |||||
| } | |||||
// Holds a strong reference to every destination GradFn reached below until
// this function returns.
| std::vector<std::shared_ptr<GradFn>> ref_keeper; | |||||
| ref_keeper.reserve(tape.size()); | |||||
| // back-propagation in reverse order | |||||
| for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) { | |||||
| auto&& grad_fn = tape[k].lock(); | |||||
| if (!grad_fn) continue; | |||||
| if (grad_fn->backward_graph) { | |||||
| for (size_t i = 0; i < grad_fn->slots.size(); ++i) { | |||||
| // grad_fn->dsts correspond to input tensors during forward | |||||
| // calculation, grad_fn->slots correspond to output tensors. | |||||
| // condition true means the output tensor has gradient for | |||||
| // back-propagation | |||||
| if (grad_fn->backward_graph->save_for_backward[grad_fn->dsts.size() + grad_fn->slots.size() + i]) { | |||||
| grad_fn->closure.push_back(std::move(grad_fn->slots[i].grad)); | |||||
| } | |||||
| } | |||||
| ApplyContext ctx; | |||||
| ctx.op = grad_fn->backward_graph->backward; | |||||
| ctx.flags = 0; | |||||
| ctx.nargs = grad_fn->closure.size(); | |||||
// Runtime-sized array (variable-length array, a compiler extension in C++).
| Tensor* args[ctx.nargs]; | |||||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
| args[i] = grad_fn->closure[i].get(); | |||||
| mgb_assert(args[i]); | |||||
| ctx.flags |= args[i]->m_flags; | |||||
| } | |||||
| ctx.args = args; | |||||
| auto grads = apply(ctx); | |||||
// The backward op yields one gradient per input flagged in input_has_grad,
// in input order; j walks those results.
| size_t j = 0; | |||||
| for (size_t i = 0; i < grad_fn->dsts.size(); ++i) { | |||||
| if (grad_fn->backward_graph->input_has_grad[i]) { | |||||
| auto& dst = grad_fn->dsts[i]; | |||||
| // grads[j] is consumed in accum_grad | |||||
| accum_grad(dst->grad, std::move(grads[j])); | |||||
| ++j; | |||||
| } | |||||
| } | |||||
| mgb_assert(j == grads.size()); | |||||
| } | |||||
| for (auto&& dst : grad_fn->dsts) { | |||||
| if (!dst.grad_fn) continue; | |||||
| if (!dst.grad_fn->in_ref_keeper) { | |||||
| dst.grad_fn->in_ref_keeper = true; | |||||
| ref_keeper.push_back(dst.grad_fn); | |||||
| } | |||||
| // grad_fn->clear will unlink current dst.producer_record | |||||
| // such that if dst.producer_record.next is false, dst accumulates | |||||
| // all the gradients | |||||
| if (!dst.producer_record.next && dst->callback && dst->grad) { | |||||
| dst->callback(TensorWrapper::make(pytype, dst->grad)); | |||||
| } | |||||
| } | |||||
| grad_fn->clear(); | |||||
| } // finish tape loop | |||||
| } | |||||
| void GradKey::cleanup() { | |||||
| active = false; | |||||
| tape.clear(); | |||||
| for (intrusive_list::Iterator it(free_vars_head); it;) { | |||||
| it->grad_fn.reset(); | |||||
| (it++)->unlink(); | |||||
| } | |||||
| } | |||||
// Forwards both vectors to the wrapped GradKey::backward.
| void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) { | |||||
| m_key->backward(std::move(tensors), std::move(grads)); | |||||
| } | |||||
// Runs cleanup() so tensors still linked to this key are detached on destruction.
| GradKey::~GradKey() { | |||||
| cleanup(); | |||||
| } | |||||
| } // namespace mgb::imperative::python | |||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * \file imperative/python/src/grad.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "./tensor.h" | |||||
| #include <megbrain/utils/small_vector.h> | |||||
| #include <memory> | |||||
| namespace mgb::imperative::python { | |||||
| apply_result_t apply_grad(ApplyContext& ctx); | |||||
// One gradient-recording session: tensors are attached to it, forward ops are
// recorded on its tape, and backward() replays the tape once.
| struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj { | |||||
| std::string name; | |||||
// Cleared by backward() and cleanup(); attach() and backward() reject an
// inactive key.
| bool active = true; | |||||
// Head of the intrusive list of GradInfo nodes linked by attach() and apply_grad().
| GradInfo::head_t free_vars_head; | |||||
// Weak references to recorded GradFn objects in forward order.
| std::vector<std::weak_ptr<GradFn>> tape; | |||||
| ~GradKey(); | |||||
| void attach(Tensor* tensor, pybind11::object callback); | |||||
| void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); | |||||
| void cleanup(); | |||||
| }; | |||||
// Python-facing wrapper (type name "GradKey") that owns a shared GradKey.
| struct GradKeyWrapper { | |||||
| using wrap_t = pyext17::wrap<GradKeyWrapper>; | |||||
| static constexpr auto tp_name = pybind11::detail::_("GradKey"); | |||||
| std::shared_ptr<GradKey> m_key; | |||||
// Every wrapper starts with a fresh GradKey.
| inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {} | |||||
// Takes the raw positional-argument array: (tensor, callback-or-None).
| void attach(PyObject*const* args, size_t nargs); | |||||
| void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); | |||||
| }; | |||||
| } // namespace mgb::imperative::python | |||||
// Makes pybind11 convert GradKeyWrapper through the caster supplied by
// pyext17::wrap<GradKeyWrapper>.
| namespace pybind11::detail { | |||||
| template<> struct type_caster<mgb::imperative::python::GradKeyWrapper> : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {}; | |||||
| } // namespace pybind11::detail | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * \file imperative/python/src/grad_info.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "./intrusive_list.h" | |||||
| namespace mgb::imperative::python { | |||||
| struct GradFn; | |||||
| struct GradSlot; | |||||
// Owning reference to one GradSlot: the GradFn that holds it plus its index
// in GradFn::slots.
| struct GradSlotPtr { | |||||
| std::shared_ptr<GradFn> grad_fn; | |||||
| size_t idx; | |||||
// Accesses the referenced slot; defined in grad.cpp where GradFn is complete.
| GradSlot* operator->(); | |||||
| }; | |||||
// Gradient bookkeeping carried by a tensor (Tensor::m_grad_info): a
// GradSlotPtr that is also an intrusive-list node, using the before_t copy policy.
// The copy operations take a non-const reference and are defaulted explicitly
// because the const-reference versions of the Node base are deleted.
| struct GradInfo : GradSlotPtr, intrusive_list::Node<GradInfo, intrusive_list::before_t> { | |||||
| GradInfo() = default; | |||||
| GradInfo(GradInfo&) = default; | |||||
| GradInfo(GradInfo&&) = default; | |||||
| GradInfo& operator=(GradInfo&) = default; | |||||
| GradInfo& operator=(GradInfo&&) = default; | |||||
| }; | |||||
| } // namespace mgb::imperative::python | |||||
| @@ -0,0 +1,227 @@ | |||||
| /** | |||||
| * \file imperative/python/src/intrusive_list.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megbrain/utils/metahelper.h" | |||||
| namespace mgb::imperative::python::intrusive_list { | |||||
| // copy policy | |||||
// Empty tag types. They select the Node constructor / insert() overload
// (after_t, before_t) and serve as Node's `policy` template argument:
// after_t / before_t enable copying by linking the copy next to the source,
// disable_t (the default) leaves copying unavailable.
| struct after_t {}; | |||||
| struct before_t {}; | |||||
| struct disable_t {}; | |||||
| template <typename T> struct Tail; | |||||
| // invariant: next->prev == this | |||||
// Forward-link half of a list element: holds `next`, a pointer to the
// following element's Tail part. Not copyable; moving transfers the link.
| template <typename T> | |||||
| struct Head { | |||||
| Tail<T>* next; | |||||
| Head(Tail<T>* node = nullptr) : next(node) {} | |||||
| Head(const Head<T>&) = delete; | |||||
| Head<T>& operator=(const Head<T>&) = delete; | |||||
// Takes over rhs's successor and repoints that successor's prev to this object.
| Head(Head<T>&& rhs) : next(rhs.next) { | |||||
| rhs.next = nullptr; | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| } | |||||
// Move assignment requires that this head currently has no successor.
| Head<T>& operator=(Head<T>&& rhs) { | |||||
| mgb_assert(!next); | |||||
| next = rhs.next; | |||||
| rhs.next = nullptr; | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
// On destruction the successor, if any, is left without a predecessor.
| ~Head() { | |||||
| if (next) { | |||||
| next->prev = nullptr; | |||||
| } | |||||
| } | |||||
| }; | |||||
| // invariant: prev->next == this | |||||
// Backward-link half of a list element: holds `prev`, a pointer to the
// preceding element's Head part. Not copyable; moving transfers the link.
| template <typename T> | |||||
| struct Tail { | |||||
| Head<T>* prev; | |||||
| Tail(Head<T>* node = nullptr) : prev(node) {} | |||||
| Tail(const Tail<T>&) = delete; | |||||
| Tail<T>& operator=(const Tail<T>&) = delete; | |||||
// Takes over rhs's predecessor and repoints that predecessor's next to this object.
| Tail(Tail<T>&& rhs) : prev(rhs.prev) { | |||||
| rhs.prev = nullptr; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| } | |||||
// Move assignment requires that this tail currently has no predecessor.
| Tail<T>& operator=(Tail<T>&& rhs) { | |||||
| mgb_assert(!prev); | |||||
| prev = rhs.prev; | |||||
| rhs.prev = nullptr; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
// On destruction the predecessor, if any, is left without a successor.
| ~Tail() { | |||||
| if (prev) { | |||||
| prev->next = nullptr; | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <typename T, typename policy> struct Node; | |||||
| template <typename T> | |||||
| class Iterator { | |||||
| T* ptr; | |||||
| void inc() {ptr = static_cast<T*>(ptr->Head<T>::next);} | |||||
| void dec() {ptr = static_cast<T*>(ptr->Head<T>::prev);} | |||||
| public: | |||||
| Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {} | |||||
| Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {} | |||||
| template<typename policy> | |||||
| Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {} | |||||
| T& operator*() {return *static_cast<T*>(ptr);} | |||||
| T* operator->() {return static_cast<T*>(ptr);} | |||||
| operator bool() {return ptr;} | |||||
| bool operator==(const Iterator<T>& rhs) {return ptr == rhs.ptr;} | |||||
| Iterator& operator++() {inc(); return *this;} | |||||
| Iterator& operator--() {dec(); return *this;} | |||||
| Iterator operator++(int) {auto ret = *this; inc(); return ret;} | |||||
| Iterator operator--(int) {auto ret = *this; dec(); return ret;} | |||||
| }; | |||||
| // Node in a doubly linked list. Unlike std::list, nodes are not owned by a container. | |||||
| // Instead, nodes may join or leave a list freely. | |||||
| // NOTE: Derived classes have to explicitly declare copy / assignment as default, | |||||
| // otherwise the compiler generated version would use the const T& signature, | |||||
| // which is deleted. | |||||
| template <typename T = void, typename policy = disable_t> | |||||
| struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>, | |||||
| Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> { | |||||
| private: | |||||
| using this_t = Node<T, policy>; | |||||
| using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>; | |||||
| public: | |||||
| using head_t = Head<U>; | |||||
| using tail_t = Tail<U>; | |||||
| using head_t::next; | |||||
| using tail_t::prev; | |||||
| Node() = default; | |||||
| Node(const this_t&) = delete; | |||||
| this_t& operator=(const this_t&) = delete; | |||||
| //! constructed node is inserted after the input node | |||||
| Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) { | |||||
| node.next = this; | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| } | |||||
| //! constructed node is inserted before the input node | |||||
| Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) { | |||||
| node.prev = this; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| } | |||||
| Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) { | |||||
| rhs.prev = nullptr; | |||||
| rhs.next = nullptr; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| } | |||||
| Node& operator=(this_t&& rhs) { | |||||
| unlink(); | |||||
| prev = rhs.prev; | |||||
| next = rhs.next; | |||||
| rhs.prev = nullptr; | |||||
| rhs.next = nullptr; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| template<typename p = policy, | |||||
| typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||||
| Node(this_t& rhs) : Node(policy{}, rhs) {} | |||||
| template<typename p = policy, | |||||
| typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||||
| this_t& operator=(this_t& rhs) { | |||||
| insert(policy{}, rhs); | |||||
| return *this; | |||||
| } | |||||
| void unlink() { | |||||
| if (prev) { | |||||
| prev->next = next; | |||||
| } | |||||
| if (next) { | |||||
| next->prev = prev; | |||||
| } | |||||
| prev = nullptr; | |||||
| next = nullptr; | |||||
| } | |||||
| //! this node is unlinked from its list and inserted after the input node | |||||
| void insert(after_t, head_t& node) { | |||||
| unlink(); | |||||
| prev = &node; | |||||
| next = node.next; | |||||
| node.next = this; | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| } | |||||
| //! this node is unlinked from its list and inserted before the input node | |||||
| void insert(before_t, tail_t& node) { | |||||
| unlink(); | |||||
| next = &node; | |||||
| prev = node.prev; | |||||
| node.prev = this; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| } | |||||
| void insert_before(tail_t& node) {insert(before_t{}, node);} | |||||
| void insert_after(head_t& node) {insert(after_t{}, node);} | |||||
| ~Node() { | |||||
| unlink(); | |||||
| } | |||||
| }; | |||||
| } // namespace mgb::imperative::python::intrusive_list | |||||
| @@ -23,7 +23,10 @@ | |||||
| #include "./dispatcher.h" | #include "./dispatcher.h" | ||||
| #include "./tensor.h" | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| using namespace mgb::imperative::python; | |||||
| #ifndef MODULE_NAME | #ifndef MODULE_NAME | ||||
| #define MODULE_NAME imperative_rt | #define MODULE_NAME imperative_rt | ||||
| @@ -68,4 +71,6 @@ PYBIND11_MODULE(MODULE_NAME, m) { | |||||
| py::getattr(m, "__dict__")); | py::getattr(m, "__dict__")); | ||||
| init_dispatcher(submodule(m, "dispatcher")); | init_dispatcher(submodule(m, "dispatcher")); | ||||
| init_tensor(submodule(m, "core2")); | |||||
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <utility> | #include <utility> | ||||
| #include <Python.h> | #include <Python.h> | ||||
| #include <pybind11/pybind11.h> | |||||
| namespace pyext17 { | namespace pyext17 { | ||||
| @@ -53,6 +54,26 @@ inline PyObject* cvt_retval(PyObject* rv) { | |||||
| return cvt_retval(__VA_ARGS__); \ | return cvt_retval(__VA_ARGS__); \ | ||||
| } | } | ||||
//! Identity conversion used by CVT_RET_INT for callables that already return int.
inline int cvt_retint(int status) {
    return status;
}
// Evaluate the expression and return from the enclosing function with an int
// status: the expression's own value (via cvt_retint) when it is non-void,
// otherwise 0 after evaluating it for its side effects.
#define CVT_RET_INT(...)                                              \
    if constexpr (!std::is_same_v<decltype(__VA_ARGS__), void>) {     \
        return cvt_retint(__VA_ARGS__);                               \
    } else {                                                          \
        __VA_ARGS__;                                                  \
        return 0;                                                     \
    }
//! Exception type for which HANDLE_ALL_EXC returns RET without touching the
//! Python error state.
struct py_err_set : std::exception {};

// Catch clauses to append after a try block inside a C-API entry point; every
// handler returns RET:
//  - py_err_set:                     return as-is
//  - pybind11::error_already_set:    restore the captured Python error
//  - pybind11::builtin_exception:    set the corresponding Python error
//  - any other std::exception:       raise RuntimeError with e.what()
#define HANDLE_ALL_EXC(RET)                                   \
    catch (py_err_set&) {                                     \
        return RET;                                           \
    }                                                         \
    catch (pybind11::error_already_set& e) {                  \
        e.restore();                                          \
        return RET;                                           \
    }                                                         \
    catch (pybind11::builtin_exception& e) {                  \
        e.set_error();                                        \
        return RET;                                           \
    }                                                         \
    catch (std::exception& e) {                               \
        PyErr_SetString(PyExc_RuntimeError, e.what());        \
        return RET;                                           \
    }
| template <typename T> | template <typename T> | ||||
| struct wrap { | struct wrap { | ||||
| private: | private: | ||||
| @@ -111,7 +132,9 @@ private: | |||||
| static PyObject* impl(PyObject* self, PyObject*) { | static PyObject* impl(PyObject* self, PyObject*) { | ||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
| CVT_RET_PYOBJ((inst->*f)()); | |||||
| try { | |||||
| CVT_RET_PYOBJ((inst->*f)()); | |||||
| } HANDLE_ALL_EXC(nullptr) | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -121,7 +144,9 @@ private: | |||||
| static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | ||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
| CVT_RET_PYOBJ((inst->*f)(args, kwargs)); | |||||
| try { | |||||
| CVT_RET_PYOBJ((inst->*f)(args, kwargs)); | |||||
| } HANDLE_ALL_EXC(nullptr) | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -132,7 +157,9 @@ private: | |||||
| static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) { | static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) { | ||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
| CVT_RET_PYOBJ((inst->*f)(args, nargs)); | |||||
| try { | |||||
| CVT_RET_PYOBJ((inst->*f)(args, nargs)); | |||||
| } HANDLE_ALL_EXC(nullptr) | |||||
| } | } | ||||
| #else | #else | ||||
| static constexpr int flags = METH_VARARGS; | static constexpr int flags = METH_VARARGS; | ||||
| @@ -141,7 +168,9 @@ private: | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
| auto* arr = &PyTuple_GET_ITEM(args, 0); | auto* arr = &PyTuple_GET_ITEM(args, 0); | ||||
| auto size = PyTuple_GET_SIZE(args); | auto size = PyTuple_GET_SIZE(args); | ||||
| CVT_RET_PYOBJ((inst->*f)(arr, size)); | |||||
| try { | |||||
| CVT_RET_PYOBJ((inst->*f)(arr, size)); | |||||
| } HANDLE_ALL_EXC(nullptr) | |||||
| } | } | ||||
| #endif | #endif | ||||
| }; | }; | ||||
| @@ -152,7 +181,9 @@ private: | |||||
| static PyObject* impl(PyObject* self, PyObject* obj) { | static PyObject* impl(PyObject* self, PyObject* obj) { | ||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
| CVT_RET_PYOBJ((inst->*f)(obj)); | |||||
| try { | |||||
| CVT_RET_PYOBJ((inst->*f)(obj)); | |||||
| } HANDLE_ALL_EXC(nullptr) | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -162,6 +193,55 @@ private: | |||||
| return {name, (PyCFunction)M::impl, M::flags, doc}; | return {name, (PyCFunction)M::impl, M::flags, doc}; | ||||
| } | } | ||||
| template<auto f> | |||||
| struct getter { | |||||
| using F = decltype(f); | |||||
| static PyObject* impl(PyObject* self, void* closure) { | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| try { | |||||
| if constexpr (std::is_invocable_v<F, PyObject*, void*>) { | |||||
| CVT_RET_PYOBJ(f(self, closure)); | |||||
| } else if constexpr (std::is_invocable_v<F, T, void*>) { | |||||
| CVT_RET_PYOBJ((inst->*f)(closure)); | |||||
| } else if constexpr (std::is_invocable_v<F, T>) { | |||||
| CVT_RET_PYOBJ((inst->*f)()); | |||||
| } else { | |||||
| static_assert(!std::is_same_v<F, F>); | |||||
| } | |||||
| } HANDLE_ALL_EXC(nullptr) | |||||
| } | |||||
| }; | |||||
| template<auto f> | |||||
| struct setter { | |||||
| using F = decltype(f); | |||||
| template<typename = void> | |||||
| static int impl_(PyObject* self, PyObject* val, void* closure) { | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| try { | |||||
| if constexpr (std::is_invocable_v<F, PyObject*, PyObject*, void*>) { | |||||
| CVT_RET_INT(f(self, val, closure)); | |||||
| } else if constexpr (std::is_invocable_v<F, T, PyObject*, void*>) { | |||||
| CVT_RET_INT((inst->*f)(val, closure)); | |||||
| } else if constexpr (std::is_invocable_v<F, T, PyObject*>) { | |||||
| CVT_RET_INT((inst->*f)(val)); | |||||
| } else { | |||||
| static_assert(!std::is_same_v<F, F>); | |||||
| } | |||||
| } HANDLE_ALL_EXC(-1) | |||||
| } | |||||
| static constexpr auto impl = []() {if constexpr (std::is_same_v<F, std::nullptr_t>) return nullptr; | |||||
| else return impl_<>;}(); | |||||
| }; | |||||
| template<auto get, auto set = nullptr> | |||||
| static constexpr PyGetSetDef make_getset_def(const char* name, const char* doc = nullptr, void* closure = nullptr) { | |||||
| return {const_cast<char *>(name), getter<get>::impl, setter<set>::impl, const_cast<char *>(doc), closure}; | |||||
| } | |||||
| // polyfills | // polyfills | ||||
| struct tp_vectorcall { | struct tp_vectorcall { | ||||
| @@ -216,16 +296,26 @@ private: | |||||
| template<typename = void> | template<typename = void> | ||||
| static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { | static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { | ||||
| struct FreeGuard { | |||||
| PyObject* self; | |||||
| PyTypeObject* type; | |||||
| ~FreeGuard() {if (self) type->tp_free(self);} | |||||
| }; | |||||
| auto* self = type->tp_alloc(type, 0); | auto* self = type->tp_alloc(type, 0); | ||||
| FreeGuard free_guard{self, type}; | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
| if constexpr (has_vectorcall && tp_vectorcall::valid) { | if constexpr (has_vectorcall && tp_vectorcall::valid) { | ||||
| reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | ||||
| } | } | ||||
| if constexpr (varkw) { | |||||
| new(inst) T(args, kwargs); | |||||
| } else { | |||||
| new(inst) T(); | |||||
| } | |||||
| try { | |||||
| if constexpr (varkw) { | |||||
| new(inst) T(args, kwargs); | |||||
| } else { | |||||
| new(inst) T(); | |||||
| } | |||||
| } HANDLE_ALL_EXC(nullptr) | |||||
| free_guard.self = nullptr; | |||||
| return self; | return self; | ||||
| } | } | ||||
| @@ -250,6 +340,7 @@ private: | |||||
| public: | public: | ||||
| class TypeBuilder { | class TypeBuilder { | ||||
| std::vector<PyMethodDef> m_methods; | std::vector<PyMethodDef> m_methods; | ||||
| std::vector<PyGetSetDef> m_getsets; | |||||
| PyTypeObject m_type; | PyTypeObject m_type; | ||||
| bool m_finalized = false; | bool m_finalized = false; | ||||
| bool m_ready = false; | bool m_ready = false; | ||||
| @@ -259,6 +350,13 @@ public: | |||||
| throw std::runtime_error("type is already finalized"); | throw std::runtime_error("type is already finalized"); | ||||
| } | } | ||||
| } | } | ||||
| static const char* to_c_str(const char* s) {return s;} | |||||
| template <size_t N, typename... Ts> | |||||
| static const char* to_c_str(const pybind11::detail::descr<N, Ts...>& desc) { | |||||
| return desc.text; | |||||
| } | |||||
| public: | public: | ||||
| TypeBuilder(const TypeBuilder&) = delete; | TypeBuilder(const TypeBuilder&) = delete; | ||||
| TypeBuilder& operator=(const TypeBuilder&) = delete; | TypeBuilder& operator=(const TypeBuilder&) = delete; | ||||
| @@ -266,7 +364,7 @@ public: | |||||
| TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} { | TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} { | ||||
| constexpr auto has_tp_name = HAS_MEMBER(T, tp_name); | constexpr auto has_tp_name = HAS_MEMBER(T, tp_name); | ||||
| if constexpr (has_tp_name) { | if constexpr (has_tp_name) { | ||||
| m_type.tp_name = T::tp_name; | |||||
| m_type.tp_name = to_c_str(T::tp_name); | |||||
| } | } | ||||
| m_type.tp_dealloc = tp_dealloc::value; | m_type.tp_dealloc = tp_dealloc::value; | ||||
| #ifdef _Py_TPFLAGS_HAVE_VECTORCALL | #ifdef _Py_TPFLAGS_HAVE_VECTORCALL | ||||
| @@ -291,8 +389,17 @@ public: | |||||
| return m_ready; | return m_ready; | ||||
| } | } | ||||
| bool isinstance(PyObject* op) { | |||||
| return PyObject_TypeCheck(op, &m_type); | |||||
| } | |||||
| bool isexact(PyObject* op) { | |||||
| return Py_TYPE(op) == &m_type; | |||||
| } | |||||
| PyObject* finalize() { | PyObject* finalize() { | ||||
| if (!m_finalized) { | if (!m_finalized) { | ||||
| m_finalized = true; | |||||
| if (m_methods.size()) { | if (m_methods.size()) { | ||||
| m_methods.push_back({0}); | m_methods.push_back({0}); | ||||
| if (m_type.tp_methods) { | if (m_type.tp_methods) { | ||||
| @@ -301,6 +408,14 @@ public: | |||||
| } | } | ||||
| m_type.tp_methods = &m_methods[0]; | m_type.tp_methods = &m_methods[0]; | ||||
| } | } | ||||
| if (m_getsets.size()) { | |||||
| m_getsets.push_back({0}); | |||||
| if (m_type.tp_getset) { | |||||
| PyErr_SetString(PyExc_SystemError, "tp_getset is already set"); | |||||
| return nullptr; | |||||
| } | |||||
| m_type.tp_getset = &m_getsets[0]; | |||||
| } | |||||
| if (PyType_Ready(&m_type)) { | if (PyType_Ready(&m_type)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -315,12 +430,64 @@ public: | |||||
| m_methods.push_back(make_meth_def<f>(name, doc)); | m_methods.push_back(make_meth_def<f>(name, doc)); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| template<auto get, auto set = nullptr> | |||||
| TypeBuilder& def_getset(const char* name, const char* doc = nullptr, void* closure = nullptr) { | |||||
| check_finalized(); | |||||
| m_getsets.push_back(make_getset_def<get, set>(name, doc, closure)); | |||||
| return *this; | |||||
| } | |||||
| }; | }; | ||||
| static TypeBuilder& type() { | static TypeBuilder& type() { | ||||
| static TypeBuilder type_helper; | static TypeBuilder type_helper; | ||||
| return type_helper; | return type_helper; | ||||
| } | } | ||||
| template<typename... Args> | |||||
| static PyObject* cnew(Args&&... args) { | |||||
| auto* pytype = type().operator->(); | |||||
| auto* self = pytype->tp_alloc(pytype, 0); | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| if constexpr (has_vectorcall && tp_vectorcall::valid) { | |||||
| reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | |||||
| } | |||||
| new(inst) T(std::forward<Args>(args)...); | |||||
| return self; | |||||
| } | |||||
| template<typename... Args> | |||||
| static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) { | |||||
| auto* self = pytype->tp_alloc(pytype, 0); | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| if constexpr (has_vectorcall && tp_vectorcall::valid) { | |||||
| reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | |||||
| } | |||||
| new(inst) T(std::forward<Args>(args)...); | |||||
| return self; | |||||
| } | |||||
| struct caster { | |||||
| static constexpr auto name = T::tp_name; | |||||
| T* value; | |||||
| bool load(pybind11::handle src, bool convert) { | |||||
| if (wrap_t::type().isinstance(src.ptr())) { | |||||
| value = reinterpret_cast<wrap_t*>(src.ptr())->inst(); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| template <typename U> using cast_op_type = pybind11::detail::cast_op_type<U>; | |||||
| operator T*() { return value; } | |||||
| operator T&() { return *value; } | |||||
| }; | |||||
| }; | }; | ||||
| } // namespace pyext17 | } // namespace pyext17 | ||||
| @@ -328,3 +495,5 @@ public: | |||||
| #undef HAS_MEMBER_TYPE | #undef HAS_MEMBER_TYPE | ||||
| #undef HAS_MEMBER | #undef HAS_MEMBER | ||||
| #undef CVT_RET_PYOBJ | #undef CVT_RET_PYOBJ | ||||
| #undef CVT_RET_INT | |||||
| #undef HANDLE_ALL_EXC | |||||
| @@ -0,0 +1,257 @@ | |||||
| /** | |||||
| * \file imperative/python/src/tensor.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "./tensor.h" | |||||
| #include "./grad.h" | |||||
| #include "./common.h" | |||||
| #include "./numpy_dtypes.h" | |||||
| #include <pybind11/numpy.h> | |||||
| #include <pybind11/operators.h> | |||||
| namespace py = pybind11; | |||||
| namespace mgb::imperative::python { | |||||
| std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | |||||
| apply_result_t apply(ApplyContext& ctx) { | |||||
| // emulating scalar should be put to specific op's apply, e.g., | |||||
| // elementwise, reduce, typecvt. Currently it's still handled at python | |||||
| // side. It could be move to C++ side if it has an impact on performance | |||||
| if (ctx.flags & Tensor::Flags::SCALAR) { | |||||
| // TODO: emulate scalar | |||||
| } | |||||
| if (ctx.flags & Tensor::Flags::GRAD) { | |||||
| return apply_grad(ctx); | |||||
| } | |||||
| if (ctx.flags & Tensor::Flags::TRACE) { | |||||
| // TODO: trace | |||||
| } else { | |||||
| SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); | |||||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
| handles[i] = ctx.args[i]->m_handle.get(); | |||||
| } | |||||
| auto output_handles = interpreter_for_py->apply_op(ctx.op, handles); | |||||
| apply_result_t outputs; | |||||
| outputs.reserve(output_handles.size()); | |||||
| for (auto h : output_handles) { | |||||
| outputs.emplace_back(std::make_shared<Tensor>(h)); | |||||
| } | |||||
| return outputs; | |||||
| } | |||||
| mgb_assert(0); | |||||
| } | |||||
| PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) { | |||||
| try { | |||||
| // if (kwnames && PyTuple_GET_SIZE(kwnames)) { | |||||
| // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed"); | |||||
| // return nullptr; | |||||
| // } | |||||
| if (!nargs) { | |||||
| PyErr_SetString(PyExc_TypeError, "expect Op"); | |||||
| return nullptr; | |||||
| } | |||||
| auto* op = args[0]; | |||||
| if (!strcmp(op->ob_type->tp_base->tp_name,"PodOpVisitor") || !strcmp(op->ob_type->tp_base->tp_name,"IndexingOpBase")){ | |||||
| op = PyObject_CallMethod(op,"to_c",""); | |||||
| } | |||||
| PyTypeObject* pytype = args[1]->ob_type; | |||||
| ++args; | |||||
| --nargs; | |||||
| ApplyContext ctx; | |||||
| ctx.flags = 0; | |||||
| ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>(); | |||||
| SmallVector<Tensor*, 64> tensors(nargs); | |||||
| ctx.args = &tensors[0]; | |||||
| ctx.nargs = nargs; | |||||
| for (size_t i = 0; i < nargs; ++i) { | |||||
| TensorWrapper* tw = TensorWrapper::cast_safe(args[i]); | |||||
| if (!tw) { | |||||
| PyErr_SetString(PyExc_TypeError, "expect Tensor"); | |||||
| return nullptr; | |||||
| } | |||||
| auto* t = tensors[i] = tw->m_tensor.get(); | |||||
| ctx.flags |= t->m_flags; | |||||
| } | |||||
| // TODO: set TRACE flag | |||||
| auto outputs = apply(ctx); | |||||
| size_t nout = outputs.size(); | |||||
| auto ret = py::tuple(nout); | |||||
| for (size_t i = 0; i < nout; ++i) { | |||||
| ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); | |||||
| } | |||||
| return ret.release().ptr(); | |||||
| } catch (std::exception& e) { | |||||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
| if (kwargs && PyDict_Size(kwargs)) { | |||||
| throw py::type_error("keyword argument not allowed"); | |||||
| } | |||||
| auto nargs = PyTuple_Size(args); | |||||
| auto tup = py::reinterpret_borrow<py::tuple>(args); | |||||
| if (nargs == 0) { | |||||
| throw py::type_error("too few arguments"); | |||||
| } | |||||
| if (auto* t = cast_safe(tup[0].ptr())) { | |||||
| if (nargs > 1) { | |||||
| throw py::type_error("expect 1 argument"); | |||||
| } | |||||
| m_tensor = t->m_tensor; | |||||
| } else { | |||||
| if (nargs != 3) { | |||||
| throw py::type_error("expect 3 arguments"); | |||||
| } | |||||
| py::detail::loader_life_support life_sup; // required to cast DType | |||||
| auto data = tup[0].cast<py::array>(); | |||||
| DType dtype = tup[1].cast<DType>(); | |||||
| CompNode cn = tup[2].cast<CompNode>(); | |||||
| interpreter::Interpreter::Handle handle; | |||||
| constexpr auto size_threshhold = TensorShape::MAX_NDIM; | |||||
| if (data.size() > size_threshhold) { | |||||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); | |||||
| } else { | |||||
| HostTensorND ret(cn); | |||||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); | |||||
| } | |||||
| m_tensor = std::make_shared<Tensor>(handle); | |||||
| if (data.ndim() == 0) { | |||||
| m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||||
| } | |||||
| } | |||||
| } | |||||
| PyObject* TensorWrapper::shape() { | |||||
| if (m_tensor->m_flags & Tensor::Flags::SCALAR) { | |||||
| return PyTuple_New(0); | |||||
| } | |||||
| auto&& shape = m_tensor->shape(); | |||||
| if (!shape.ndim) { | |||||
| Py_RETURN_NONE; | |||||
| } | |||||
| py::tuple ret(shape.ndim); | |||||
| for (size_t i = 0; i < shape.ndim; ++i) { | |||||
| ret[i] = shape[i]; | |||||
| } | |||||
| return ret.release().ptr(); | |||||
| } | |||||
| PyObject* TensorWrapper::dtype() { | |||||
| return py::cast(m_tensor->dtype()).release().ptr(); | |||||
| } | |||||
| PyObject* TensorWrapper::device() { | |||||
| return py::cast(m_tensor->comp_node()).release().ptr(); | |||||
| } | |||||
| PyObject* TensorWrapper::numpy() { | |||||
| auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); | |||||
| auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); | |||||
| if (!arr) return nullptr; | |||||
| if (m_tensor->m_flags & Tensor::Flags::SCALAR) { | |||||
| mgb_assert(PyArray_Check(arr.ptr())); | |||||
| return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr())); | |||||
| } | |||||
| return arr.release().ptr(); | |||||
| } | |||||
| void TensorWrapper::reset(PyObject* tensor) { | |||||
| TensorWrapper* t = TensorWrapper::cast_safe(tensor); | |||||
| if (!t) { | |||||
| throw py::type_error("expect Tensor"); | |||||
| } | |||||
| m_tensor = t->m_tensor; | |||||
| } | |||||
| PyObject* TensorWrapper::isscalar() { | |||||
| if(m_tensor->m_flags & Tensor::Flags::SCALAR) { | |||||
| Py_RETURN_TRUE; | |||||
| } else { | |||||
| Py_RETURN_FALSE; | |||||
| } | |||||
| } | |||||
| void TensorWrapper::setscalar() { | |||||
| m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||||
| } | |||||
| struct TensorWeakRef { | |||||
| std::weak_ptr<Tensor> wptr; | |||||
| TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {} | |||||
| py::object operator()() { | |||||
| if (auto p = wptr.lock()) { | |||||
| return TensorWrapper::make(p); | |||||
| } | |||||
| return py::none(); | |||||
| } | |||||
| }; | |||||
| void init_tensor(py::module m) { | |||||
| interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | |||||
| auto* tensor_type = TensorWrapper::wrap_t::type() | |||||
| .def<&TensorWrapper::numpy>("numpy") | |||||
| .def_getset<&TensorWrapper::shape>("shape") | |||||
| .def_getset<&TensorWrapper::dtype>("dtype") | |||||
| .def_getset<&TensorWrapper::device>("device") | |||||
| .def<&TensorWrapper::reset>("_reset") | |||||
| .def<&TensorWrapper::isscalar>("isscalar") | |||||
| .def<&TensorWrapper::setscalar>("setscalar") | |||||
| .finalize(); | |||||
| if (!tensor_type) throw py::error_already_set(); | |||||
| py::setattr(m, "Tensor", tensor_type); | |||||
| py::class_<TensorWeakRef>(m, "TensorWeakRef") | |||||
| .def(py::init<const TensorWrapper&>()) | |||||
| .def("__call__", &TensorWeakRef::operator()); | |||||
| static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; | |||||
| auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr); | |||||
| if (!apply_func) throw py::error_already_set(); | |||||
| py::setattr(m, "apply", apply_func); | |||||
| py::handle grad_key_type = GradKeyWrapper::wrap_t::type() | |||||
| .def<&GradKeyWrapper::attach>("attach") | |||||
| .finalize(); | |||||
| if (!grad_key_type) throw py::error_already_set(); | |||||
| py::setattr(m, "GradKey", grad_key_type); | |||||
| py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward)); | |||||
| } | |||||
| } // namespace mgb::imperative::python | |||||
| @@ -0,0 +1,157 @@ | |||||
| /** | |||||
| * \file imperative/python/src/tensor.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <variant> | |||||
| #include "megbrain/imperative/interpreter.h" | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "./pyext17.h" | |||||
| namespace mgb::imperative::python { | |||||
| template<typename T, typename B = pybind11::object> | |||||
| struct ObjectPtr : B { | |||||
| using B::B; | |||||
| T& operator*() {return reinterpret_cast<T&>(*B::ptr());} | |||||
| T* operator->() {return reinterpret_cast<T*>(B::ptr());} | |||||
| }; | |||||
| } // namespace mgb::imperative::python | |||||
| #include "./grad_info.h" // for struct GradInfo | |||||
| namespace mgb::imperative::python { | |||||
//! Per-tensor tracing state; currently has no members.
struct TraceInfo {};
| extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | |||||
| class SharedHandle { | |||||
| using Handle = interpreter::Interpreter::Handle; | |||||
| static_assert(std::is_pointer_v<Handle>); | |||||
| std::shared_ptr<std::remove_pointer_t<Handle>> holder; | |||||
| public: | |||||
| inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){ | |||||
| interpreter_for_py->del(h); | |||||
| }) {} | |||||
| SharedHandle(const SharedHandle&) = default; | |||||
| SharedHandle& operator=(const SharedHandle&) = default; | |||||
| SharedHandle(SharedHandle&&) = default; | |||||
| SharedHandle& operator=(SharedHandle&&) = default; | |||||
| inline Handle get() {return holder.get();} | |||||
| }; | |||||
| struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||||
| using flags_t = uint64_t; | |||||
| struct Flags { | |||||
| static constexpr flags_t SCALAR = 1; | |||||
| static constexpr flags_t GRAD = 1 << 1; | |||||
| static constexpr flags_t TRACE = 1 << 2; | |||||
| }; | |||||
| flags_t m_flags = 0; | |||||
| GradInfo m_grad_info; | |||||
| TraceInfo m_trace_info; | |||||
| SharedHandle m_handle; | |||||
| using Handle = interpreter::Interpreter::Handle; | |||||
| inline explicit Tensor(Handle handle) : m_handle(handle) {} | |||||
| inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)) {} | |||||
| ~Tensor() = default; | |||||
| inline std::shared_ptr<Tensor> copy() { | |||||
| auto ret = std::make_shared<Tensor>(m_handle); | |||||
| ret->m_flags = m_flags; | |||||
| ret->m_grad_info = m_grad_info; | |||||
| ret->m_trace_info = m_trace_info; | |||||
| return ret; | |||||
| } | |||||
| inline DType dtype() {return interpreter_for_py->get_dtype(m_handle.get());} | |||||
| inline CompNode comp_node() {return interpreter_for_py->get_device(m_handle.get());} | |||||
| inline TensorShape shape() {return interpreter_for_py->get_shape(m_handle.get());} | |||||
| }; | |||||
| struct TensorWrapper { | |||||
| std::shared_ptr<Tensor> m_tensor; | |||||
| inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) : m_tensor(std::move(tensor)) {} | |||||
| TensorWrapper(PyObject* args, PyObject* kwargs); | |||||
| ~TensorWrapper() = default; | |||||
| static constexpr auto tp_name = pybind11::detail::_("Tensor"); | |||||
| using wrap_t = pyext17::wrap<TensorWrapper>; | |||||
| friend wrap_t; | |||||
| inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();} | |||||
| inline static TensorWrapper* cast_safe(PyObject* op) { | |||||
| if (!wrap_t::type().isinstance(op)) return nullptr; | |||||
| return cast(op); | |||||
| } | |||||
| inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);} | |||||
| template <typename... Args> | |||||
| static ObjectPtr<Tensor> make(Args&&... args) { | |||||
| auto* op = wrap_t::cnew(std::forward<Args>(args)...); | |||||
| return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op); | |||||
| } | |||||
| template <typename... Args> | |||||
| static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) { | |||||
| auto* op = wrap_t::cnew_with_type(pytype,std::forward<Args>(args)...); | |||||
| return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op); | |||||
| } | |||||
| PyObject* shape(); | |||||
| PyObject* dtype(); | |||||
| PyObject* device(); | |||||
| PyObject* numpy(); | |||||
| void reset(PyObject*); | |||||
| PyObject* isscalar(); | |||||
| void setscalar(); | |||||
| }; | |||||
| PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | |||||
| struct ApplyContext { | |||||
| Tensor::flags_t flags; | |||||
| std::shared_ptr<OpDef> op; | |||||
| Tensor*const* args; | |||||
| size_t nargs; | |||||
| }; | |||||
| using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; | |||||
| apply_result_t apply(ApplyContext& ctx); | |||||
| void init_tensor(pybind11::module); | |||||
| } // namespace mgb::imperative::python | |||||
| namespace pybind11::detail { | |||||
| template<> struct type_caster<mgb::imperative::python::TensorWrapper> : mgb::imperative::python::TensorWrapper::wrap_t::caster {}; | |||||
| } // namespace pybind11::detail | |||||
| @@ -0,0 +1,17 @@ | |||||
| /** | |||||
| * \file imperative/python/src/trace.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| namespace mgb::imperative::python { | |||||
//! Per-tensor tracing state; currently has no members.
struct TraceInfo {};
| } // namespace mgb::imperative::python | |||||
| @@ -12,6 +12,7 @@ def test_basic(): | |||||
| config_async_level(3) | config_async_level(3) | ||||
| @pytest.mark.skip | |||||
| def test_level1_infer_value(): | def test_level1_infer_value(): | ||||
| config_async_level(1) | config_async_level(1) | ||||
| a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32") | a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32") | ||||
| @@ -22,6 +23,7 @@ def test_level1_infer_value(): | |||||
| d = F.reshape(a, c) | d = F.reshape(a, c) | ||||
| @pytest.mark.skip | |||||
| def test_level1_infer_shape_with_unknown(): | def test_level1_infer_shape_with_unknown(): | ||||
| config_async_level(2) | config_async_level(2) | ||||
| a = mge.tensor([[1, 2, 2, 3]], dtype="float32") | a = mge.tensor([[1, 2, 2, 3]], dtype="float32") | ||||
| @@ -16,12 +16,11 @@ import pytest | |||||
| import megengine as mge | import megengine as mge | ||||
| import megengine.distributed as dist | import megengine.distributed as dist | ||||
| import megengine.functional as F | import megengine.functional as F | ||||
| from megengine.core._imperative_rt import TensorAttr, imperative | |||||
| from megengine.core._imperative_rt import TensorAttr, core2, imperative | |||||
| from megengine.core._imperative_rt.core2 import TensorWeakRef, apply | |||||
| from megengine.core._imperative_rt.imperative import sync | |||||
| from megengine.core.autodiff.grad import Grad | from megengine.core.autodiff.grad import Grad | ||||
| from megengine.core.ops.builtin import Elemwise | from megengine.core.ops.builtin import Elemwise | ||||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | |||||
| from megengine.core.tensor.tensor import Tensor, apply | |||||
| from megengine.core.tensor.tensor_wrapper import TensorWrapper | |||||
| from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
| from megengine.functional.distributed import remote_recv, remote_send | from megengine.functional.distributed import remote_recv, remote_send | ||||
| @@ -43,11 +42,11 @@ relu = _elwise(Elemwise.Mode.RELU) | |||||
| def as_tensor(x): | def as_tensor(x): | ||||
| return Tensor(as_raw_tensor(x, device=mge.device.get_default_device())) | |||||
| return mge.Tensor(x) | |||||
| def save_to(self, name="grad"): | def save_to(self, name="grad"): | ||||
| def callback(tensor, grad): | |||||
| def callback(grad): | |||||
| setattr(self, name, grad) | setattr(self, name, grad) | ||||
| return callback | return callback | ||||
| @@ -136,14 +135,14 @@ def test_2nd_grad(): | |||||
| def test_grad_with_tensor_wrapper(): | def test_grad_with_tensor_wrapper(): | ||||
| x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = mge.Tensor(x_np) | |||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = mul(x, x) | y = mul(x, x) | ||||
| y = mul(y, y) | y = mul(y, y) | ||||
| grad(y, TensorWrapper(np.ones_like(x_np))) | |||||
| grad(y, mge.Tensor(np.ones_like(x_np))) | |||||
| np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | ||||
| @@ -162,8 +161,8 @@ def test_release(): | |||||
| finally: | finally: | ||||
| gc.enable() | gc.enable() | ||||
| x = TensorWrapper([0.0]) | |||||
| dy = TensorWrapper(np.ones_like(x.numpy())) | |||||
| x = mge.Tensor([0.0]) | |||||
| dy = mge.Tensor(np.ones_like(x.numpy())) | |||||
| @check | @check | ||||
| def _(): | def _(): | ||||
| @@ -173,25 +172,25 @@ def test_release(): | |||||
| @check | @check | ||||
| def _(): | def _(): | ||||
| with Grad().wrt(x) as g: | |||||
| with Grad().wrt(x): | |||||
| pass | pass | ||||
| @check | @check | ||||
| def _(): | def _(): | ||||
| with Grad().wrt(x) as g: | |||||
| with Grad().wrt(x): | |||||
| y = x * x | y = x * x | ||||
| def test_grad_inplace(): | def test_grad_inplace(): | ||||
| x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = mge.Tensor(x_np) | |||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = mul(x, x) | y = mul(x, x) | ||||
| y *= y | y *= y | ||||
| grad(y, TensorWrapper(np.ones_like(x_np))) | |||||
| grad(y, mge.Tensor(np.ones_like(x_np))) | |||||
| np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) | ||||
| @@ -199,16 +198,16 @@ def test_elemwise_add(): | |||||
| x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
| y_np = np.random.rand(10, 10).astype("float32") | y_np = np.random.rand(10, 10).astype("float32") | ||||
| dz_np = np.random.rand(10, 10).astype("float32") | dz_np = np.random.rand(10, 10).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| y = TensorWrapper(y_np) | |||||
| dz = TensorWrapper(dz_np) | |||||
| x = mge.Tensor(x_np) | |||||
| y = mge.Tensor(y_np) | |||||
| dz = mge.Tensor(dz_np) | |||||
| refs = {} | refs = {} | ||||
| def f(x, y): | def f(x, y): | ||||
| x = x * 2 | x = x * 2 | ||||
| refs["x"] = weakref.ref(x.__wrapped__) | |||||
| refs["y"] = weakref.ref(y.__wrapped__) | |||||
| refs["x"] = TensorWeakRef(x) | |||||
| refs["y"] = TensorWeakRef(y) | |||||
| return x + y | return x + y | ||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| @@ -226,14 +225,14 @@ def test_elemwise_add(): | |||||
| def test_elemwise_relu(): | def test_elemwise_relu(): | ||||
| x_np = [1.0, -1.0] | x_np = [1.0, -1.0] | ||||
| dz_np = [1.0] | dz_np = [1.0] | ||||
| x = TensorWrapper(x_np) | |||||
| dz = TensorWrapper(dz_np) | |||||
| x = mge.Tensor(x_np) | |||||
| dz = mge.Tensor(dz_np) | |||||
| refs = {} | refs = {} | ||||
| def f(x): | def f(x): | ||||
| x = x * 2 | x = x * 2 | ||||
| refs["x"] = weakref.ref(x.__wrapped__) | |||||
| refs["x"] = TensorWeakRef(x) | |||||
| return relu(x) | return relu(x) | ||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| @@ -258,7 +257,7 @@ def test_elemwise_relu_backward_fn(): | |||||
| def test_reshape(): | def test_reshape(): | ||||
| x_np = np.random.rand(2, 5).astype("float32") | x_np = np.random.rand(2, 5).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = mge.Tensor(x_np) | |||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = x.reshape(5, 2) | y = x.reshape(5, 2) | ||||
| @@ -269,7 +268,7 @@ def test_reshape(): | |||||
| def test_subtensor(): | def test_subtensor(): | ||||
| x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = mge.Tensor(x_np) | |||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = x[1:-1, :2] | y = x[1:-1, :2] | ||||
| @@ -282,7 +281,7 @@ def test_subtensor(): | |||||
| def test_IndexingMultiAxisVec(): | def test_IndexingMultiAxisVec(): | ||||
| x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = mge.Tensor(x_np) | |||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = x[[0, 2], [0, 2]] | y = x[[0, 2], [0, 2]] | ||||
| @@ -295,7 +294,7 @@ def test_IndexingMultiAxisVec(): | |||||
| def test_AxisAddRemove(): | def test_AxisAddRemove(): | ||||
| x_np = np.random.rand(1, 5).astype("float32") | x_np = np.random.rand(1, 5).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = mge.Tensor(x_np) | |||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = F.squeeze(F.expand_dims(x, 2), 0) | y = F.squeeze(F.expand_dims(x, 2), 0) | ||||
| @@ -308,7 +307,7 @@ def test_AxisAddRemove(): | |||||
| def test_Broadcast(): | def test_Broadcast(): | ||||
| x_np = np.random.rand(3, 3, 1).astype("float32") | x_np = np.random.rand(3, 3, 1).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = mge.Tensor(x_np) | |||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = F.broadcast_to(x, (3, 3, 10)) | y = F.broadcast_to(x, (3, 3, 10)) | ||||
| @@ -319,7 +318,7 @@ def test_Broadcast(): | |||||
| def test_Reduce_sum(): | def test_Reduce_sum(): | ||||
| x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = mge.Tensor(x_np) | |||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = x.sum(axis=0) | y = x.sum(axis=0) | ||||
| @@ -330,7 +329,7 @@ def test_Reduce_sum(): | |||||
| def test_Reduce_mean(): | def test_Reduce_mean(): | ||||
| x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = mge.Tensor(x_np) | |||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = x.mean(axis=0) | y = x.mean(axis=0) | ||||
| @@ -11,30 +11,29 @@ import collections | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| import megengine.core.tensor.raw_tensor | |||||
| import megengine | |||||
| import megengine.tensor as Tensor | |||||
| from megengine.core._imperative_rt.core2 import apply | |||||
| from megengine.core._trace_option import use_symbolic_shape | from megengine.core._trace_option import use_symbolic_shape | ||||
| from megengine.core.ops import builtin | from megengine.core.ops import builtin | ||||
| from megengine.core.tensor import Tensor | |||||
| from megengine.core.tensor.core import apply | |||||
| from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor | |||||
| def cvt_to_shape_desc(val, inpvar, config=None): | def cvt_to_shape_desc(val, inpvar, config=None): | ||||
| def as_tensor(val, device): | def as_tensor(val, device): | ||||
| assert device is not None, "can not infer device" | assert device is not None, "can not infer device" | ||||
| # TODO: should copy to appropriate device | # TODO: should copy to appropriate device | ||||
| val = as_raw_tensor(val, device=device) | |||||
| val = Tensor(val, device=device) | |||||
| return val | return val | ||||
| device = None | device = None | ||||
| if inpvar is not None: | if inpvar is not None: | ||||
| assert isinstance(inpvar, RawTensor) | |||||
| assert isinstance(inpvar, Tensor) | |||||
| device = device or inpvar.device | device = device or inpvar.device | ||||
| if config is not None: | if config is not None: | ||||
| device = device or config.device | device = device or config.device | ||||
| if isinstance(val, RawTensor): | |||||
| if isinstance(val, Tensor): | |||||
| return as_tensor(val, device) | return as_tensor(val, device) | ||||
| if not isinstance(val, collections.abc.Iterable): | if not isinstance(val, collections.abc.Iterable): | ||||
| @@ -43,7 +42,7 @@ def cvt_to_shape_desc(val, inpvar, config=None): | |||||
| components = [] | components = [] | ||||
| on_host = True | on_host = True | ||||
| for i in val: | for i in val: | ||||
| if isinstance(i, RawTensor): | |||||
| if isinstance(i, Tensor): | |||||
| on_host = False | on_host = False | ||||
| device = device or i.device | device = device or i.device | ||||
| else: | else: | ||||
| @@ -62,7 +61,7 @@ def cvt_to_shape_desc(val, inpvar, config=None): | |||||
| return as_tensor(shape, device) | return as_tensor(shape, device) | ||||
| for idx, v in enumerate(components): | for idx, v in enumerate(components): | ||||
| if not isinstance(v, RawTensor): | |||||
| if not isinstance(v, Tensor): | |||||
| vi = int(v) | vi = int(v) | ||||
| assert vi == v, "could not convert {} to int".format(v) | assert vi == v, "could not convert {} to int".format(v) | ||||
| v = vi | v = vi | ||||
| @@ -95,7 +94,7 @@ def canonize_inputs(inputs, *, config): | |||||
| # and is called with concat([a, b])) | # and is called with concat([a, b])) | ||||
| inputs = inputs[0] | inputs = inputs[0] | ||||
| if isinstance(inputs, RawTensor): | |||||
| if isinstance(inputs, Tensor): | |||||
| return [inputs] | return [inputs] | ||||
| old_inputs = inputs | old_inputs = inputs | ||||
| @@ -103,7 +102,7 @@ def canonize_inputs(inputs, *, config): | |||||
| get_comp_node = None | get_comp_node = None | ||||
| need_cvt = False | need_cvt = False | ||||
| for i in old_inputs: | for i in old_inputs: | ||||
| if isinstance(i, RawTensor): | |||||
| if isinstance(i, Tensor): | |||||
| get_comp_node = lambda cn=i.device: cn | get_comp_node = lambda cn=i.device: cn | ||||
| else: | else: | ||||
| need_cvt = True | need_cvt = True | ||||
| @@ -117,8 +116,8 @@ def canonize_inputs(inputs, *, config): | |||||
| return config.comp_node | return config.comp_node | ||||
| for idx, var in enumerate(inputs): | for idx, var in enumerate(inputs): | ||||
| if not isinstance(var, RawTensor): | |||||
| var = as_raw_tensor(var) | |||||
| if not isinstance(var, Tensor): | |||||
| var = Tensor(var) | |||||
| inputs[idx] = var | inputs[idx] = var | ||||
| return inputs | return inputs | ||||
| @@ -131,15 +130,15 @@ def invoke_op(op, inputs_, cvt_inputs=canonize_inputs): | |||||
| def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | ||||
| assert isinstance(inp, RawTensor) | |||||
| assert isinstance(inp, Tensor) | |||||
| if not isinstance(tuple_val, tuple): | if not isinstance(tuple_val, tuple): | ||||
| tuple_val = (tuple_val,) | tuple_val = (tuple_val,) | ||||
| def as_tensor(v): | def as_tensor(v): | ||||
| if not isinstance(v, RawTensor): | |||||
| if not isinstance(v, Tensor): | |||||
| vi = np.ascontiguousarray(v, dtype=np.int32) | vi = np.ascontiguousarray(v, dtype=np.int32) | ||||
| assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v) | assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v) | ||||
| v = as_raw_tensor(vi) | |||||
| v = Tensor(vi) | |||||
| return v | return v | ||||
| new_axes = [] | new_axes = [] | ||||
| @@ -275,14 +274,14 @@ def batched_incr_mesh_indexing(input, value, tuple_val): | |||||
| def test_transpose(): | def test_transpose(): | ||||
| x = np.arange(10).reshape(2, 5).astype("int32") | x = np.arange(10).reshape(2, 5).astype("int32") | ||||
| xx = as_raw_tensor(x) | |||||
| xx = Tensor(x) | |||||
| (yy,) = transpose(xx, pattern=[1, -1, 0]) | (yy,) = transpose(xx, pattern=[1, -1, 0]) | ||||
| np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy()) | np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy()) | ||||
| def test_broadcast(): | def test_broadcast(): | ||||
| x = np.arange(10).reshape(1, 10).astype("int32") | x = np.arange(10).reshape(1, 10).astype("int32") | ||||
| xx = as_raw_tensor(x) | |||||
| xx = Tensor(x) | |||||
| (yy,) = broadcast(xx, (10, 10)) | (yy,) = broadcast(xx, (10, 10)) | ||||
| np.testing.assert_equal(np.repeat(x, 10, 0), yy.numpy()) | np.testing.assert_equal(np.repeat(x, 10, 0), yy.numpy()) | ||||
| @@ -290,7 +289,7 @@ def test_broadcast(): | |||||
| def test_subtensor(): | def test_subtensor(): | ||||
| x = np.arange(25).reshape(5, 5).astype("int32") | x = np.arange(25).reshape(5, 5).astype("int32") | ||||
| d = np.arange(2).astype("int32") | d = np.arange(2).astype("int32") | ||||
| xx = as_raw_tensor(x) | |||||
| xx = Tensor(x) | |||||
| (yy0,) = subtensor(xx, (slice(0, 4, 2), 3)) | (yy0,) = subtensor(xx, (slice(0, 4, 2), 3)) | ||||
| (yy1,) = set_subtensor(xx, d, (slice(0, 4, 2), 3)) | (yy1,) = set_subtensor(xx, d, (slice(0, 4, 2), 3)) | ||||
| (yy2,) = incr_subtensor(xx, d, (slice(0, 4, 2), 3)) | (yy2,) = incr_subtensor(xx, d, (slice(0, 4, 2), 3)) | ||||
| @@ -309,7 +308,7 @@ def test_subtensor(): | |||||
| def test_advance_indexing(): | def test_advance_indexing(): | ||||
| x = np.arange(25).reshape(5, 5).astype("int32") | x = np.arange(25).reshape(5, 5).astype("int32") | ||||
| d = np.arange(15).reshape(3, 5).astype("int32") | d = np.arange(15).reshape(3, 5).astype("int32") | ||||
| xx = as_raw_tensor(x) | |||||
| xx = Tensor(x) | |||||
| (yy0,) = advance_indexing(xx, ((0, 4, 2), slice(None, None, None))) | (yy0,) = advance_indexing(xx, ((0, 4, 2), slice(None, None, None))) | ||||
| (yy1,) = set_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None))) | (yy1,) = set_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None))) | ||||
| (yy2,) = incr_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None))) | (yy2,) = incr_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None))) | ||||
| @@ -328,7 +327,7 @@ def test_advance_indexing(): | |||||
| def test_mesh_indexing(): | def test_mesh_indexing(): | ||||
| x = np.arange(25).reshape(5, 5).astype("int32") | x = np.arange(25).reshape(5, 5).astype("int32") | ||||
| d = np.arange(6).reshape(3, 2).astype("int32") | d = np.arange(6).reshape(3, 2).astype("int32") | ||||
| xx = as_raw_tensor(x) | |||||
| xx = Tensor(x) | |||||
| (yy0,) = mesh_indexing(xx, (slice(0, 5, 2), (1, 3))) | (yy0,) = mesh_indexing(xx, (slice(0, 5, 2), (1, 3))) | ||||
| (yy1,) = set_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3))) | (yy1,) = set_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3))) | ||||
| (yy2,) = incr_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3))) | (yy2,) = incr_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3))) | ||||
| @@ -355,7 +354,7 @@ def test_mesh_indexing(): | |||||
| def test_batched_mesh_indexing(): | def test_batched_mesh_indexing(): | ||||
| x = np.arange(24).reshape(2, 3, 4).astype("int32") | x = np.arange(24).reshape(2, 3, 4).astype("int32") | ||||
| d = np.arange(12).reshape(2, 2, 3).astype("int32") | d = np.arange(12).reshape(2, 2, 3).astype("int32") | ||||
| xx = as_raw_tensor(x) | |||||
| xx = Tensor(x) | |||||
| s = [(0, 1, 2), (1, 2, 3)] | s = [(0, 1, 2), (1, 2, 3)] | ||||
| (yy0,) = batched_mesh_indexing(xx, (slice(None, None, None), [(0, 2)] * 2, s)) | (yy0,) = batched_mesh_indexing(xx, (slice(None, None, None), [(0, 2)] * 2, s)) | ||||
| (yy1,) = batched_set_mesh_indexing( | (yy1,) = batched_set_mesh_indexing( | ||||
| @@ -9,12 +9,12 @@ | |||||
| import numpy as np | import numpy as np | ||||
| from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | ||||
| from megengine.core.tensor.tensor_wrapper import TensorWrapper | |||||
| from megengine.tensor import Tensor | |||||
| def test_basic(): | def test_basic(): | ||||
| x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = Tensor(x_np) | |||||
| y = x * x | y = x * x | ||||
| y_np = y.numpy() | y_np = y.numpy() | ||||
| np.testing.assert_almost_equal(y_np, x_np * x_np) | np.testing.assert_almost_equal(y_np, x_np * x_np) | ||||
| @@ -22,15 +22,15 @@ def test_basic(): | |||||
| def test_literal_arith(): | def test_literal_arith(): | ||||
| x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
| x = TensorWrapper(x_np) | |||||
| x = Tensor(x_np) | |||||
| y = x * 2 | y = x * 2 | ||||
| y_np = y.numpy() | y_np = y.numpy() | ||||
| np.testing.assert_almost_equal(y_np, x_np * 2) | np.testing.assert_almost_equal(y_np, x_np * 2) | ||||
| def test_matmul(): | def test_matmul(): | ||||
| A = TensorWrapper(np.random.rand(5, 7).astype("float32")) | |||||
| B = TensorWrapper(np.random.rand(7, 10).astype("float32")) | |||||
| A = Tensor(np.random.rand(5, 7).astype("float32")) | |||||
| B = Tensor(np.random.rand(7, 10).astype("float32")) | |||||
| C = A @ B | C = A @ B | ||||
| np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) | np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) | ||||
| @@ -38,7 +38,7 @@ def test_matmul(): | |||||
| def test_reduce(): | def test_reduce(): | ||||
| def test_x(x_np): | def test_x(x_np): | ||||
| for m in ["sum", "prod", "min", "max", "mean"]: | for m in ["sum", "prod", "min", "max", "mean"]: | ||||
| x = TensorWrapper(x_np) | |||||
| x = Tensor(x_np) | |||||
| y = getattr(x, m)(axis=-1, keepdims=True) | y = getattr(x, m)(axis=-1, keepdims=True) | ||||
| np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6) | np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6) | ||||
| @@ -49,7 +49,7 @@ def test_reduce(): | |||||
| def test_set_subtensor(): | def test_set_subtensor(): | ||||
| x = TensorWrapper([1, 2, 3]) | |||||
| x = Tensor([1, 2, 3]) | |||||
| x[:] = [1, 1, 1] | x[:] = [1, 1, 1] | ||||
| np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) | np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) | ||||
| x[[0, 2]] = [3, 2] | x[[0, 2]] = [3, 2] | ||||
| @@ -60,7 +60,7 @@ def test_set_subtensor(): | |||||
| def test_computing_with_numpy_array(): | def test_computing_with_numpy_array(): | ||||
| x = np.array([1, 2, 3], dtype=np.int32) | x = np.array([1, 2, 3], dtype=np.int32) | ||||
| xx = TensorWrapper(x, device="cpu0") | |||||
| xx = Tensor(x, device="cpu0") | |||||
| y = np.array([1, 0, 3], dtype=np.int32) | y = np.array([1, 0, 3], dtype=np.int32) | ||||
| assert np.add(xx, y).device == xx.device | assert np.add(xx, y).device == xx.device | ||||
| np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y)) | np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y)) | ||||
| @@ -70,12 +70,12 @@ def test_computing_with_numpy_array(): | |||||
| def test_transpose(): | def test_transpose(): | ||||
| x = np.random.rand(2, 5).astype("float32") | x = np.random.rand(2, 5).astype("float32") | ||||
| xx = TensorWrapper(x) | |||||
| xx = Tensor(x) | |||||
| np.testing.assert_almost_equal(xx.T.numpy(), x.T) | np.testing.assert_almost_equal(xx.T.numpy(), x.T) | ||||
| def test_as_type(): | def test_as_type(): | ||||
| x = TensorWrapper([1, 2, 3], dtype=np.float32) | |||||
| x = Tensor([1, 2, 3], dtype=np.float32) | |||||
| y = x.astype(qint8(0.1)) | y = x.astype(qint8(0.1)) | ||||
| np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) | np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) | ||||
| z = y.astype(qint8(0.2)) | z = y.astype(qint8(0.2)) | ||||
| @@ -312,7 +312,7 @@ def test_device(): | |||||
| np.testing.assert_almost_equal(y1.numpy(), y2.numpy()) | np.testing.assert_almost_equal(y1.numpy(), y2.numpy()) | ||||
| y3 = F.eye(x.shape, dtype="float32", device="xpux") | y3 = F.eye(x.shape, dtype="float32", device="xpux") | ||||
| y4 = F.eye(x.shape, dtype="float32", device=x.device.to_c()) | |||||
| y4 = F.eye(x.shape, dtype="float32", device=x.device) | |||||
| np.testing.assert_almost_equal(y3.numpy(), y4.numpy()) | np.testing.assert_almost_equal(y3.numpy(), y4.numpy()) | ||||
| y5 = F.full((3, 2), 4, device=x.device) | y5 = F.full((3, 2), 4, device=x.device) | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
| #include "./op_trait.h" | #include "./op_trait.h" | ||||
| #include "./proxy_graph_detail.h" | |||||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace imperative { | namespace imperative { | ||||
| @@ -13,7 +13,7 @@ | |||||
| #if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
| #include "../op_trait.h" | #include "../op_trait.h" | ||||
| #include "../proxy_graph_detail.h" | |||||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||||
| #include "megbrain/opr/mm_handler.h" | #include "megbrain/opr/mm_handler.h" | ||||
| #include "megbrain/utils/hash.h" | #include "megbrain/utils/hash.h" | ||||
| #endif // MGB_ENABLE_OPR_MM | #endif // MGB_ENABLE_OPR_MM | ||||
| @@ -13,7 +13,7 @@ | |||||
| #if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
| #include "../op_trait.h" | #include "../op_trait.h" | ||||
| #include "../proxy_graph_detail.h" | |||||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||||
| #include "megbrain/opr/io_remote.h" | #include "megbrain/opr/io_remote.h" | ||||
| #include "megbrain/opr/mm_handler.h" | #include "megbrain/opr/mm_handler.h" | ||||
| #endif // MGB_ENABLE_OPR_MM | #endif // MGB_ENABLE_OPR_MM | ||||
| @@ -13,7 +13,7 @@ | |||||
| #include "megbrain/serialization/opr_load_dump.h" | #include "megbrain/serialization/opr_load_dump.h" | ||||
| #include "../op_trait.h" | #include "../op_trait.h" | ||||
| #include "../proxy_graph_detail.h" | |||||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace imperative { | namespace imperative { | ||||
| @@ -10,7 +10,7 @@ | |||||
| */ | */ | ||||
| #include "./proxy_graph.h" | #include "./proxy_graph.h" | ||||
| #include "./proxy_graph_detail.h" | |||||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace imperative { | namespace imperative { | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * \file imperative/src/impl/proxy_graph_detail.h | |||||
| * \file imperative/src/include/megbrain/imperative/proxy_graph_detail.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||