@@ -26,4 +26,6 @@ def set_symbolic_shape(option: bool):
     """ Sets whether tensor.shape returns a tensor instead of a tuple
     """
     global _use_symbolic_shape
+    _org = _use_symbolic_shape
     _use_symbolic_shape = option
+    return _org
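Note: `set_symbolic_shape` now also returns the previous setting, so callers can flip the flag temporarily and restore it afterwards (the `trace` hunks later in this diff rely on exactly that). A minimal save/restore sketch, illustrative only and not part of the patch:

    from megengine.core._trace_option import set_symbolic_shape

    old = set_symbolic_shape(True)   # returns the previous value
    try:
        pass                         # code that relies on symbolic shapes goes here
    finally:
        set_symbolic_shape(old)      # restore the caller's setting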
@@ -14,7 +14,7 @@ from .._trace_option import use_symbolic_shape
 from ..ops import builtin
 from ..ops.special import Const
 from .core import TensorBase, TensorWrapperBase, apply
-from .utils import astensor1d, make_shape_tuple
+from .utils import astensor1d, isscalar, make_shape_tuple


 def remove_ellipsis(tensor, tuple_val):
@@ -89,9 +89,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
     if not isinstance(tuple_val, tuple):
         tuple_val = (tuple_val,)
     ndim_indexed = 0
+    ndim_indexed_scalar = 0
     for i in tuple_val:
         if not i is Ellipsis:
             ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim
+            if isscalar(i):
+                ndim_indexed_scalar += 1
     if ndim_indexed > inp.ndim:
         raise IndexError(
             "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
@@ -103,15 +107,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
     use_subtensor = True

     inp, tuple_val = check_bool_index(inp, tuple_val)

-    def is_scalar(d):
-        if isinstance(i, int):
-            return True
-        if type(d).__module__ == np.__name__:
-            return np.isscalar(d)
-        # if isinstance(d, (TensorBase, TensorWrapperBase)):
-        #     return d.shape == (1,)
-        return False
-
     new_axes = []
     tensors = []
     items = []
@@ -134,7 +129,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
             continue
         if (
-            not is_scalar(i)
+            not isscalar(i)
             and not i is np.newaxis
             and not i is Ellipsis
             and not isinstance(i, slice)
@@ -191,7 +186,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
         items.append(item)
     if new_axes:
         raise IndexError("newaxis is not allowed here")
-    return inp, tensors, items, use_subtensor
+    return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim


 def try_condtake(tensor, index):
@@ -217,11 +212,11 @@ def getitem(tensor, index):
     try_result = try_condtake(tensor, index)
     if len(try_result) == 2:
         return try_result[0]
-    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
+    tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
     for v in tensors:
         if isinstance(v.shape, v.__class__):
             break
-        if v.shape[0] == 0:
+        if len(v.shape) > 0 and v.shape[0] == 0:
             (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
                 tensor
             )
@@ -231,6 +226,8 @@ def getitem(tensor, index):
     else:
         op = builtin.IndexingMultiAxisVec(items=items)
     (result,) = apply(op, tensor, *tensors)
+    if ret_scalar:
+        result.__wrapped__._data._isscalar = True
     return result
@@ -245,9 +242,9 @@ def setitem(tensor, index, value):
     if not isinstance(value, (TensorBase, TensorWrapperBase)):
         op = Const(value, dtype=tensor.dtype, device=tensor.device)
         (value,) = op(tensor)
-    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
+    tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
     for v in tensors:
-        if v.shape[0] == 0:
+        if len(v.shape) > 0 and v.shape[0] == 0:
             return tensor
     if use_subtensor:
         op = builtin.Subtensor(items=items)
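Note: `unpack_getitem` now also reports whether the index expression consists purely of scalar indices covering every dimension, and `getitem` uses that flag to mark the result as 0-dim. A small sketch of the resulting behaviour, assuming the rest of this patch is applied (illustrative, not part of the patch):

    import numpy as np
    from megengine import Tensor

    x = Tensor(np.arange(6).reshape(2, 3))
    assert x[0, 1].ndim == 0    # every dimension indexed by a scalar -> 0-dim result
    assert x[0:1, 1].ndim == 1  # a slice is involved -> result keeps a real shape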
@@ -102,8 +102,9 @@ class Graph(_imperative_rt.ComputingGraph):


 class VarNode(TensorBase):
-    def __init__(self, node: _imperative_rt.VarNode):
+    def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
         self._node = node
+        self._isscalar = isscalar
         if hasattr(self.graph, "_var_cache"):
             self.graph._var_cache[node] = self
@@ -33,8 +33,9 @@ class RawTensor(TensorBase):
     _del_cb = None
     _handle = None

-    def __init__(self, handle=None):
+    def __init__(self, handle=None, isscalar=False):
         self._handle = handle
+        self._isscalar = isscalar
         if handle is not None:
             if self._init_cb:
                 self._init_cb()
@@ -49,10 +50,15 @@ class RawTensor(TensorBase):

     @property
     def shape(self):
+        if self._isscalar:
+            return ()
         return get_shape(self._handle)

     def numpy(self):
-        return get_value(self._handle)
+        ret = get_value(self._handle)
+        if self._isscalar:
+            ret = ret.squeeze()
+        return ret

     def _dev_tensor(self):
         return _get_dev_tensor(self._handle)
@@ -102,7 +108,7 @@ def _(array: np.ndarray, dtype=None, device=None):
     device = None if device is None else as_device(device).to_c()
     if 0 in array.strides:
         array = array.squeeze().reshape(array.shape)
-    return RawTensor(put(array, dtype=dtype, device=device))
+    return RawTensor(put(array, dtype=dtype, device=device), isscalar=(array.ndim == 0))


 @as_raw_tensor.register(RawTensor)
@@ -21,7 +21,9 @@ from .indexing import getitem as _getitem
 from .indexing import setitem as _setitem
 from .raw_tensor import RawTensor, as_raw_tensor
 from .tensor import Tensor
+from .utils import isscalar
 from .utils import make_shape_tuple as _make_shape_tuple
+from .utils import setscalar

 _ElwMod = Elemwise.Mode
@@ -39,6 +41,13 @@ def _elwise(*args, mode):
         )
     args = utils.convert_inputs(*args)
     (result,) = apply(op, *args)
+    _isscalar = True
+    for i in args:
+        if isscalar(i) == False:
+            _isscalar = False
+            break
+    if _isscalar:
+        setscalar(result)
     return result
@@ -153,6 +162,8 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
     param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
     op = builtin.AxisAddRemove(param=param)
     (result,) = apply(op, inp)
+    if len(axis) == inp.ndim:
+        setscalar(result)
     return result
@@ -189,6 +200,8 @@ def _reduce(mode):
         if self.dtype == np.bool_:
             if mode in ["MIN", "MAX"]:
                 result = result.astype("bool")
+        if axis is None or self.ndim == 1:
+            setscalar(result)
         return result

     return f
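Note: scalar-ness is propagated rather than inferred from the shape: an elementwise op produces a scalar only when all of its inputs are scalars, and a reduction produces a scalar when it collapses every remaining axis (`axis is None`, or the input is already 1-dim). A small sketch of the intended behaviour under this patch (illustrative only):

    from megengine import Tensor

    a = Tensor(1.0)            # 0-dim tensor
    b = Tensor([1.0, 2.0])     # 1-dim tensor
    assert (a + a).ndim == 0   # scalar op scalar -> scalar
    assert (a + b).ndim == 1   # mixing in a non-scalar keeps a real shape
    assert b.sum().ndim == 0   # full reduction -> scalar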
@@ -321,9 +334,7 @@ class ArrayMethodMixin(abc.ABC):
     __complex__ = lambda self: complex(self.item())

     def __len__(self):
-        shape = self.shape
-        if use_symbolic_shape():
-            shape = shape.numpy()
+        shape = self.__wrapped__.shape
         if shape:
             return int(shape[0])
         raise TypeError("ndim is 0")
@@ -344,18 +355,17 @@ class ArrayMethodMixin(abc.ABC):

     @property
     def ndim(self):
-        shape = self.shape
-        if isinstance(shape, self.__class__):
-            # XXX: assume ndim is not changed during trace
-            ndim = shape.__wrapped__.shape[0]
-            return ndim
+        shape = self.__wrapped__.shape
+        if shape is None:
+            raise ValueError("unkown ndim")
         return len(shape)

     @property
     def size(self):
-        if use_symbolic_shape():
-            return self.shape.prod()
-        return np.prod(self.shape).item()
+        shape = self.shape
+        if shape.__class__ is tuple:
+            return np.prod(self.shape).item()
+        return shape.prod()

     @property
     def T(self):
@@ -416,8 +426,8 @@ class ArrayMethodMixin(abc.ABC):

         .. testoutput::

-            [2]
-            [10.]
+            2
+            10.

         """
         return _reduce("SUM")(self, axis, keepdims)
@@ -444,10 +454,10 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):

     @property
     def shape(self):
-        if use_symbolic_shape():
-            return apply(GetVarShape(), self)[0]
-        else:
-            return self.__wrapped__.shape
+        shape = self.__wrapped__.shape
+        if shape == () or not use_symbolic_shape():
+            return shape
+        return apply(GetVarShape(), self)[0]

     @property
     def device(self):
@@ -133,7 +133,9 @@ def concatenate(inputs, axis=0, *, device=None):
 def astype(x, dtype):
     dtype = np.dtype(dtype)
     if not is_equal(x.dtype, dtype):
+        isscalar = x.__wrapped__._data._isscalar
         (x,) = apply(builtin.TypeCvt(param=dtype), x)
+        x.__wrapped__._data._isscalar = isscalar
     return x
@@ -176,13 +178,29 @@ def result_type(*args):


 def isscalar(x):
-    try:
-        return x.ndim == 0
-    except:
-        pass
+    if isinstance(x, TensorWrapperBase):
+        x = x.__wrapped__
+    if hasattr(x, "_isscalar"):
+        return x._isscalar
+    if isinstance(x, TensorBase):
+        return x._data._isscalar
     return np.isscalar(x)


+def setscalar(x):
+    if isinstance(x, TensorWrapperBase):
+        x = x.__wrapped__
+    if hasattr(x, "_isscalar"):
+        x._isscalar = True
+    elif isinstance(x, TensorBase):
+        x._data._isscalar = True
+    else:
+        raise NotImplementedError("Unsupport type {}".format(type(x)))
+
+
 def astensor1d(x, *reference, dtype=None, device=None):
     """
     Convert something to 1D tensor. Support following types
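Note: the graph runtime keeps storing these values with a concrete shape (hence the `or (1,)` fallbacks elsewhere in this diff), so scalar-ness lives in a Python-side `_isscalar` flag on the wrapped tensor; `isscalar`/`setscalar` are the helpers that read and set it through the wrapper layers. A rough sketch of how an op wrapper is meant to use them; `add_preserving_scalar` is a hypothetical helper, not part of the patch:

    from megengine.core.tensor.utils import isscalar, setscalar

    def add_preserving_scalar(x, y):
        # hypothetical wrapper: mark the result 0-dim only if both inputs are 0-dim
        result = x + y
        if isscalar(x) and isscalar(y):
            setscalar(result)
        return result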
@@ -195,8 +213,8 @@ def astensor1d(x, *reference, dtype=None, device=None):
     except AttributeError:
         pass
     else:
-        if ndim != 1:
-            raise ValueError("ndim != 1: %d" % ndim)
+        if ndim != 0 and ndim != 1:
+            raise ValueError("ndim != 1 or 0, get : %d" % ndim)
         if not isinstance(x, (TensorBase, TensorWrapperBase)):
             (x,) = Const(x, dtype=dtype, device=device)(*reference)
         return x
@@ -216,7 +234,11 @@ def astensor1d(x, *reference, dtype=None, device=None):
 def _expand_int(s, i):
     if isinstance(i, (TensorBase, TensorWrapperBase)):
-        s += list(i.numpy())
+        i_np = i.numpy()
+        if i_np.ndim == 0:
+            s.append(int(i_np))
+        else:
+            s += list(i_np)
         return
     if isinstance(i, Iterable):
         for ii in i:
@@ -63,8 +63,12 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
     """
     op = ParamPackSplit()
     op.offsets = offsets
-    op.shapes = shapes
-    return apply(op, inp)
+    op.shapes = [s or (1,) for s in shapes]
+    outputs = apply(op, inp)
+    for s, x in zip(shapes, outputs):
+        if not s:
+            x._isscalar = True
+    return outputs


 def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
@@ -13,6 +13,7 @@ from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
 from ..core.tensor import megbrain_graph, utils
 from ..core.tensor.core import apply
+from ..core.tensor.utils import isscalar, setscalar
 from ..device import get_default_device
 from ..jit.tracing import is_tracing
 from ..tensor import Tensor
@@ -105,7 +106,14 @@ def _elwise(*args, mode):
     args = utils.convert_inputs(*args)
     if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
         args = tuple(map(lambda x: x.astype("float32"), args))
+    _isscalar = True
+    for i in args:
+        if isscalar(i) == False:
+            _isscalar = False
+            break
     (result,) = apply(op, *args)
+    if _isscalar:
+        setscalar(result)
     return result
@@ -63,7 +63,7 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:

     .. testoutput::

-        [2.75]
+        2.75

     """
     diff = pred - label
@@ -115,7 +115,7 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:

     .. testoutput::

-        [9.75]
+        9.75

     """
     diff = pred - label
@@ -170,7 +170,7 @@ def cross_entropy(

     .. testoutput::

-        [0.6931]
+        0.6931

     """
     n0 = pred.ndim
@@ -226,7 +226,7 @@ def binary_cross_entropy(

     .. testoutput::

-        [0.6931]
+        0.6931

     """
     if not with_logits:
@@ -265,7 +265,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:

     .. testoutput::

-        [1.5]
+        1.5

     """
     assert norm in ["L1", "L2"], "norm must be L1 or L2"
@@ -155,7 +155,7 @@ def sum(

     .. testoutput::

-        [21]
+        21

     """
     return inp.sum(axis=axis, keepdims=keepdims)
@@ -189,7 +189,7 @@ def prod(

     .. testoutput::

-        [720]
+        720

     """
     return inp.prod(axis=axis, keepdims=keepdims)
@@ -226,7 +226,7 @@ def mean(

     .. testoutput::

-        [3.5]
+        3.5

     """
     return inp.mean(axis=axis, keepdims=keepdims)
@@ -263,7 +263,7 @@ def var(

     .. testoutput::

-        [2.9167]
+        2.9167

     """
     if axis is None:
         m = mean(inp, axis=axis, keepdims=False)
@@ -340,7 +340,7 @@ def min(

     .. testoutput::

-        [1]
+        1

     """
     return inp.min(axis=axis, keepdims=keepdims)
@@ -377,7 +377,7 @@ def max(

     .. testoutput::

-        [6]
+        6

     """
     return inp.max(axis=axis, keepdims=keepdims)
@@ -412,7 +412,7 @@ def norm(

     .. testoutput::

-        [4.3589]
+        4.3589

     """
     if axis is None:
@@ -460,7 +460,7 @@ def argmin(

     .. testoutput::

-        [0]
+        0

     """
     if isinstance(axis, collections.abc.Iterable):
@@ -519,7 +519,7 @@ def argmax(

     .. testoutput::

-        [5]
+        5

     """
     if isinstance(axis, collections.abc.Iterable):
@@ -111,6 +111,8 @@ def full(shape, value, dtype="float32", device=None):
     (x,) = Const(value, dtype=dtype, device=device)(
         Tensor(value, dtype=dtype, device=device)
     )
+    if len(shape) == 0:  # scalar
+        return x
     return broadcast_to(x, shape)
@@ -53,7 +53,7 @@ def topk_accuracy(

     .. testoutput::

-        [0.] [0.375]
+        0.0 0.375

     """
     if isinstance(topk, int):
         topk = (topk,)
@@ -168,8 +168,6 @@ class trace:
         self._output_bindings = None
         self._output_names = None

-        set_symbolic_shape(self._symbolic_shape)
-
     def _new_handle(self):
         handle = len(self._tinfo)
         info = TensorInfo()
@@ -368,6 +366,7 @@ class trace:
         interrupted = False

         def do_enter():
+            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
             self._set_active(True)
             if self._untraced:
                 self._init_trace(self._symbolic)
@@ -423,6 +422,8 @@ class trace:
                 apply.disable(apply_compiled_mode)
                 apply.disable(apply_const_compiled_mode)
                 self._set_active(False)
+            # Restore global variable
+            set_symbolic_shape(self._save_symbolic_shape)

         def do_exit():
             if not self._untraced and self._pc != len(self._seq):
@@ -498,7 +499,7 @@ class trace:
                 opnode = info.data_setter = G.InputNode(
                     device=info.device,
                     dtype=info.dtype,
-                    shape=info.shape,
+                    shape=info.shape or (1,),
                     graph=graph,
                     use_static_shape=_input_node_use_static_shape(),
                 )
@@ -544,7 +545,7 @@ class trace:
                     *links,
                     device=info.device,
                     dtype=info.dtype,
-                    shape=info.shape,
+                    shape=info.shape or (1,),
                     graph=graph,
                     use_static_shape=_input_node_use_static_shape(),
                 )
@@ -719,13 +720,13 @@ class trace:
             h2v[h] = graph.make_h2d(
                 dtype=info.dtype,
                 device=dumped_device,
-                shape=info.shape,
+                shape=info.shape or (1,),
                 name=arg_names[i] if arg_names else None,
             )
         for k, h in self._kwarg_bindings.items():
             info = self._tinfo[h]
             h2v[h] = graph.make_h2d(
-                dtype=info.dtype, device=dumped_device, shape=info.shape, name=k
+                dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
             )

         for op, ihandles, ohandles in self._seq:
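Note: the `info.shape or (1,)` fallbacks above exist because a scalar's recorded shape is `()`, which is falsy, so graph input nodes are declared with a `(1,)` placeholder; the `_isscalar` flag on the proxy tensors in the next hunks is what restores scalar semantics (shape `()`, squeezed `numpy()`) on the way back out. The idiom itself is just:

    # () is falsy, so a scalar's shape falls back to the (1,) placeholder
    assert (() or (1,)) == (1,)
    assert ((3, 4) or (1,)) == (3, 4)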
@@ -919,6 +920,7 @@ class CompiledTensorProxy(RawTensor):

     def __init__(self, handle):
         self.__handle = handle
+        self._isscalar = False
         self.__info = active_trace._tinfo[handle]
         self.__shape = None
         self.__data = None
@@ -934,6 +936,8 @@ class CompiledTensorProxy(RawTensor):

     @property
     def shape(self):
+        if self._isscalar:
+            return ()
         if self.__shape is None:
             if self.__info.shape_read:
                 self.__shape = self.__info.shape_reader.get_value().shape
@@ -951,6 +955,8 @@ class CompiledTensorProxy(RawTensor):
                 self.__value = self._dev_tensor().numpy()
             else:
                 raise TraceMismatchError("value of this tensor is not read in trace")
+        if self._isscalar:
+            self.__value = self.__value.squeeze()
         return self.__value

     def _dev_tensor(self):
@@ -970,9 +976,10 @@ class CompiledTensorProxy(RawTensor):


 class LazyEvalTensor(RawTensor):
-    def __init__(self, varnode):
-        super(LazyEvalTensor, self).__init__()
+    def __init__(self, varnode, isscalar=False):
+        super().__init__()
         self.__varnode = varnode
+        self._isscalar = isscalar

     @property
     def dtype(self):
@@ -984,10 +991,15 @@ class LazyEvalTensor(RawTensor):

     @property
     def shape(self):
+        if self._isscalar:
+            return ()
         return self.__varnode.shape

     def numpy(self):
-        return self.__varnode.value
+        ret = self.__varnode.value
+        if self._isscalar:
+            ret = ret.squeeze()
+        return ret

     def _dev_tensor(self):
         raise RuntimeError("cannot access data during symbolic tracing")
@@ -1041,10 +1053,12 @@ class TracedLazyTensor(TraceMixin, LazyEvalTensor):

 def assign_raw_tensor(lhs, rhs):
     handle = rhs._handle
+    # Keep isscalar of lhs
+    isscalar = lhs._isscalar
     rhs.__dict__.clear()
     lhs.__dict__.clear()
     lhs.__class__ = RawTensor
-    lhs.__init__(handle)
+    lhs.__init__(handle, isscalar=isscalar)


 # this hook turns RawTensor into LazyEvalTensor
@@ -1060,7 +1074,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
             data_setter = G.InputNode(
                 device=x.device,
                 dtype=x.dtype,
-                shape=x.shape,
+                shape=x.shape or (1,),
                 graph=graph,
                 use_static_shape=True,
             )
@@ -1091,7 +1105,9 @@ apply.disable(apply_symbolic_mode)
 @apply.register()
 def apply_const_symbolic_mode(op: Const, *args: RawTensor):
     graph = active_trace._lazy_eval_graph
-    ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
+    ret = LazyEvalTensor(
+        graph.make_const(op.value, dtype=op.dtype, device=op.device), isscalar=True
+    )
     active_trace._lazy_eval_tensors.add(ret)
     return (ret,)
@@ -46,9 +46,9 @@ class Observer(Module):

     def get_dtype(self):
         q_dict = self.get_qparams()
-        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
+        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
         numpy_zero_point = (
-            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
+            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
         )
         return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
@@ -18,7 +18,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)

     def forward(self, x, y):
         x = x[y] * self.a
@@ -28,7 +28,7 @@ class Simple(Module):
 class Simple2(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)

     def forward(self, x):
         x = x[1, ..., :, 0:4:2, 0:2] * self.a
@@ -18,7 +18,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)

     def forward(self, x):
         x = x[:, 0] * self.a
@@ -18,8 +18,8 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
-        self.b = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)
+        self.b = Parameter([1.0], dtype=np.float32)

     def forward(self, x):
         x = x * self.a
@@ -21,7 +21,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a
@@ -18,7 +18,7 @@ from megengine.optimizer import SGD, MultiStepLR
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a
@@ -32,7 +32,7 @@ class MLP(Module):
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a
@@ -11,7 +11,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a
@@ -19,7 +19,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a
@@ -107,7 +107,7 @@ def test_xornet_trace_dump():
         if step % 50 == 0:
             minibatch = next(val_dataset)
             _, loss = val_fun(data, label)
-            loss = loss.numpy()[0]
+            loss = loss.numpy()
             val_loss.append((step, loss))
             print("Step: {} loss={}".format(step, loss))
         opt.step()
@@ -449,7 +449,7 @@ def test_advance_indexing_high_level():
     y = np.array([1, 2])
     yy = Tensor(y)
     np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
-    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) # FIXME
+    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
     np.testing.assert_equal(x[:, y], xx[:, y].numpy())
     np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
@@ -469,10 +469,9 @@ def test_advance_indexing_high_level():
     y = np.array([1])
     yy = Tensor(y)
     np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
-    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) # FIXME
+    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
     np.testing.assert_equal(x[:, y], xx[:, y].numpy())
-    # XXX: no way to tell whether yy is scalar or ndim=1 array
     np.testing.assert_equal(x[:, y], xx[:, yy].numpy())

     x = np.arange(9).reshape(3, 3).astype("int32")
@@ -21,6 +21,7 @@ from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.core import apply
 from megengine.core.tensor.raw_tensor import as_raw_tensor
+from megengine.core.tensor.utils import isscalar
 from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
 from megengine.random import normal, uniform
@@ -263,20 +264,21 @@ def test_optimize_for_inference_broadcast():


 def test_trace_cvt_bool():
-    set_symbolic_shape(True)
     x = tensor([0], dtype=np.int32)

     @trace(symbolic=True)
     def f(x):
-        return x.shape[0] == 0
+        a = x.shape
+        b = a[0]
+        assert isscalar(b)
+        return b == 0

     for i in range(3):
-        np.testing.assert_equal(f(x).numpy()[0], False)
+        np.testing.assert_equal(f(x).numpy(), False)


 def test_trace_reshape():
     for symbolic in [False, True]:
-        set_symbolic_shape(True)
         x1 = tensor(np.random.randn(2, 10, 10))
         x2 = tensor(np.random.randn(4, 10, 10))
         x3 = tensor(np.random.randn(8, 10, 10))
@@ -359,7 +361,6 @@ def test_raise_on_trace():

 def test_trace_broadcast():
     for symbolic in [False, True]:
-        set_symbolic_shape(True)
         x1 = tensor(np.random.randn(3, 1, 1))
         x2 = tensor(np.random.randn(1, 4, 1))
         x3 = tensor(np.random.randn(1, 1, 5))
@@ -397,7 +398,6 @@ def test_trace_nms():

 def test_trace_valid_broadcast():
-    set_symbolic_shape(True)
     x1 = tensor(np.random.randn(1, 1))
     x2 = tensor(np.random.randn(1, 2))
     shape = (tensor([2]), tensor([2]))
@@ -0,0 +1,52 @@
+import numpy as np
+
+import megengine.functional as F
+from megengine import Tensor
+from megengine.core._trace_option import use_symbolic_shape
+
+
+def test_zero_dim():
+    a = Tensor(1)
+    a_np = np.array(1, dtype=np.int32)
+    np.testing.assert_equal(a, a_np)
+    if use_symbolic_shape():
+        np.testing.assert_equal(a.shape, np.array(a_np.shape))
+    else:
+        np.testing.assert_equal(a.shape, a_np.shape)
+
+
+def test_sum():
+    a = Tensor([1, 2])
+    a = a.reshape((1, 2))
+    assert a.sum().ndim == 0
+    assert a.sum(axis=1).ndim == 1
+
+
+def test_max():
+    a = Tensor([1, 2])
+    a = a.reshape((1, 2))
+    assert a.max().ndim == 0
+    assert a.max(axis=1).ndim == 1
+
+
+def test_reshape():
+    a = Tensor(1)
+    a = a.reshape((1, 1))
+
+
+def test_squeeze():
+    a = Tensor(1)
+    a = a.reshape((1, 1))
+    assert F.squeeze(a).ndim == 0
+
+
+def test_elemementwise():
+    a = Tensor(1.0)
+    assert F.exp(a).ndim == 0
+    assert (a + a).ndim == 0
+    assert (a + 1).ndim == 0
+
+
+def test_astype():
+    a = Tensor(1.0)
+    assert a.astype("int32").ndim == 0