@@ -26,4 +26,6 @@ def set_symbolic_shape(option: bool):
    """ Sets whether tensor.shape returns a tensor instead of a tuple
    """
    global _use_symbolic_shape
    _org = _use_symbolic_shape
    _use_symbolic_shape = option
    return _org
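set_symbolic_shape now reports the previous value of the flag, so callers (the trace changes further down rely on this) can save and restore it around a scoped change. A minimal sketch of that save/restore pattern, assuming the public module path megengine.core._trace_option:

from megengine.core._trace_option import set_symbolic_shape, use_symbolic_shape

saved = set_symbolic_shape(True)   # switch tensor.shape to symbolic mode, remember the old flag
try:
    assert use_symbolic_shape()
finally:
    set_symbolic_shape(saved)      # restore whatever was set before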
@@ -14,7 +14,7 @@ from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.special import Const
from .core import TensorBase, TensorWrapperBase, apply
from .utils import astensor1d, make_shape_tuple
from .utils import astensor1d, isscalar, make_shape_tuple

def remove_ellipsis(tensor, tuple_val):
@@ -89,9 +89,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
    if not isinstance(tuple_val, tuple):
        tuple_val = (tuple_val,)
    ndim_indexed = 0
    ndim_indexed_scalar = 0
    for i in tuple_val:
        if not i is Ellipsis:
            ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim
            if isscalar(i):
                ndim_indexed_scalar += 1
    if ndim_indexed > inp.ndim:
        raise IndexError(
            "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
@@ -103,15 +107,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
    use_subtensor = True
    inp, tuple_val = check_bool_index(inp, tuple_val)

    def is_scalar(d):
        if isinstance(i, int):
            return True
        if type(d).__module__ == np.__name__:
            return np.isscalar(d)
        # if isinstance(d, (TensorBase, TensorWrapperBase)):
        #     return d.shape == (1,)
        return False

    new_axes = []
    tensors = []
    items = []
@@ -134,7 +129,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
            continue
        if (
            not is_scalar(i)
            not isscalar(i)
            and not i is np.newaxis
            and not i is Ellipsis
            and not isinstance(i, slice)

@@ -191,7 +186,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
        items.append(item)
    if new_axes:
        raise IndexError("newaxis is not allowed here")
    return inp, tensors, items, use_subtensor
    return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim

def try_condtake(tensor, index):
@@ -217,11 +212,11 @@ def getitem(tensor, index):
    try_result = try_condtake(tensor, index)
    if len(try_result) == 2:
        return try_result[0]
    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
    tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
    for v in tensors:
        if isinstance(v.shape, v.__class__):
            break
        if v.shape[0] == 0:
        if len(v.shape) > 0 and v.shape[0] == 0:
            (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
                tensor
            )

@@ -231,6 +226,8 @@ def getitem(tensor, index):
    else:
        op = builtin.IndexingMultiAxisVec(items=items)
    (result,) = apply(op, tensor, *tensors)
    if ret_scalar:
        result.__wrapped__._data._isscalar = True
    return result
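Because unpack_getitem now also reports whether every indexed axis was hit by a scalar index, getitem can flag such results as zero-dimensional. A rough sketch of the resulting user-visible behaviour, using the public megengine.Tensor wrapper (the specific indexing cases are illustrative assumptions, not taken from the tests below):

import numpy as np
from megengine import Tensor

x = Tensor(np.arange(6, dtype="int32").reshape(2, 3))
assert x[1, 2].ndim == 0      # every axis indexed by a scalar -> zero-dim result
assert x[1].ndim == 1         # only one of two axes indexed -> still 1-d
assert x[1:2, 2].ndim == 1    # a slice is not a scalar index, so that axis survives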
@@ -245,9 +242,9 @@ def setitem(tensor, index, value):
    if not isinstance(value, (TensorBase, TensorWrapperBase)):
        op = Const(value, dtype=tensor.dtype, device=tensor.device)
        (value,) = op(tensor)
    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
    tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
    for v in tensors:
        if v.shape[0] == 0:
        if len(v.shape) > 0 and v.shape[0] == 0:
            return tensor
    if use_subtensor:
        op = builtin.Subtensor(items=items)

@@ -102,8 +102,9 @@ class Graph(_imperative_rt.ComputingGraph):
class VarNode(TensorBase):
    def __init__(self, node: _imperative_rt.VarNode):
    def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
        self._node = node
        self._isscalar = isscalar
        if hasattr(self.graph, "_var_cache"):
            self.graph._var_cache[node] = self
@@ -33,8 +33,9 @@ class RawTensor(TensorBase):
    _del_cb = None
    _handle = None

    def __init__(self, handle=None):
    def __init__(self, handle=None, isscalar=False):
        self._handle = handle
        self._isscalar = isscalar
        if handle is not None:
            if self._init_cb:
                self._init_cb()

@@ -49,10 +50,15 @@ class RawTensor(TensorBase):
    @property
    def shape(self):
        if self._isscalar:
            return ()
        return get_shape(self._handle)

    def numpy(self):
        return get_value(self._handle)
        ret = get_value(self._handle)
        if self._isscalar:
            ret = ret.squeeze()
        return ret

    def _dev_tensor(self):
        return _get_dev_tensor(self._handle)
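A RawTensor carrying the _isscalar flag now reports an empty shape and squeezes the value returned by numpy(). At the user level this should look roughly like the following sketch:

import numpy as np
from megengine import Tensor

a = Tensor(1.0)                 # built from a Python scalar / 0-dim array
assert a.ndim == 0
assert a.numpy().shape == ()    # numpy() comes back zero-dimensional as well

b = Tensor(np.ones((2, 3)))     # non-scalar tensors are unaffected
assert b.ndim == 2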
@@ -102,7 +108,7 @@ def _(array: np.ndarray, dtype=None, device=None):
    device = None if device is None else as_device(device).to_c()
    if 0 in array.strides:
        array = array.squeeze().reshape(array.shape)
    return RawTensor(put(array, dtype=dtype, device=device))
    return RawTensor(put(array, dtype=dtype, device=device), isscalar=(array.ndim == 0))

@as_raw_tensor.register(RawTensor)

@@ -21,7 +21,9 @@ from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
from .raw_tensor import RawTensor, as_raw_tensor
from .tensor import Tensor
from .utils import isscalar
from .utils import make_shape_tuple as _make_shape_tuple
from .utils import setscalar

_ElwMod = Elemwise.Mode
@@ -39,6 +41,13 @@ def _elwise(*args, mode):
        )
    args = utils.convert_inputs(*args)
    (result,) = apply(op, *args)
    _isscalar = True
    for i in args:
        if not isscalar(i):
            _isscalar = False
            break
    if _isscalar:
        setscalar(result)
    return result
@@ -153,6 +162,8 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
    param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
    op = builtin.AxisAddRemove(param=param)
    (result,) = apply(op, inp)
    if len(axis) == inp.ndim:
        setscalar(result)
    return result

@@ -189,6 +200,8 @@ def _reduce(mode):
        if self.dtype == np.bool_:
            if mode in ["MIN", "MAX"]:
                result = result.astype("bool")
        if axis is None or self.ndim == 1:
            setscalar(result)
        return result

    return f
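With these two hooks, reductions that collapse every axis now yield scalars: _remove_axis tags the result when all axes are removed, and the _reduce wrapper does so for full reductions (axis=None) and for reductions over the only axis of a 1-d tensor. A short sketch, mirroring the new tests at the end of this diff:

from megengine import Tensor

a = Tensor([[1, 2], [3, 4]])
assert a.sum().ndim == 0                          # full reduction -> scalar
assert a.sum(axis=1).ndim == 1                    # partial reduction keeps a 1-d result
assert Tensor([1, 2, 3]).max(axis=0).ndim == 0    # reducing a 1-d tensor's only axis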
@@ -321,9 +334,7 @@ class ArrayMethodMixin(abc.ABC):
    __complex__ = lambda self: complex(self.item())

    def __len__(self):
        shape = self.shape
        if use_symbolic_shape():
            shape = shape.numpy()
        shape = self.__wrapped__.shape
        if shape:
            return int(shape[0])
        raise TypeError("ndim is 0")

@@ -344,18 +355,17 @@ class ArrayMethodMixin(abc.ABC):
    @property
    def ndim(self):
        shape = self.shape
        if isinstance(shape, self.__class__):
            # XXX: assume ndim is not changed during trace
            ndim = shape.__wrapped__.shape[0]
            return ndim
        shape = self.__wrapped__.shape
        if shape is None:
            raise ValueError("unknown ndim")
        return len(shape)
    @property
    def size(self):
        if use_symbolic_shape():
            return self.shape.prod()
        return np.prod(self.shape).item()
        shape = self.shape
        if shape.__class__ is tuple:
            return np.prod(self.shape).item()
        return shape.prod()

    @property
    def T(self):
@@ -416,8 +426,8 @@ class ArrayMethodMixin(abc.ABC):
        .. testoutput::

            [2]
            [10.]
            2
            10.
        """
        return _reduce("SUM")(self, axis, keepdims)

@@ -444,10 +454,10 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
    @property
    def shape(self):
        if use_symbolic_shape():
            return apply(GetVarShape(), self)[0]
        else:
            return self.__wrapped__.shape
        shape = self.__wrapped__.shape
        if shape == () or not use_symbolic_shape():
            return shape
        return apply(GetVarShape(), self)[0]

    @property
    def device(self):
@@ -133,7 +133,9 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype):
    dtype = np.dtype(dtype)
    if not is_equal(x.dtype, dtype):
        isscalar = x.__wrapped__._data._isscalar
        (x,) = apply(builtin.TypeCvt(param=dtype), x)
        x.__wrapped__._data._isscalar = isscalar
    return x
@@ -176,13 +178,29 @@ def result_type(*args):
def isscalar(x):
    try:
        return x.ndim == 0
    except:
        pass
    if isinstance(x, TensorWrapperBase):
        x = x.__wrapped__
    if hasattr(x, "_isscalar"):
        return x._isscalar
    if isinstance(x, TensorBase):
        return x._data._isscalar
    return np.isscalar(x)


def setscalar(x):
    if isinstance(x, TensorWrapperBase):
        x = x.__wrapped__
    if hasattr(x, "_isscalar"):
        x._isscalar = True
    elif isinstance(x, TensorBase):
        x._data._isscalar = True
    else:
        raise NotImplementedError("Unsupported type {}".format(type(x)))
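isscalar and setscalar are the two helpers the rest of this diff leans on: isscalar reads the _isscalar flag behind either a wrapper or a raw tensor and falls back to numpy for plain Python values, while setscalar forces the flag on. A rough usage sketch of this internal API (import path taken from the test changes below):

from megengine import Tensor
from megengine.core.tensor.utils import isscalar, setscalar

assert isscalar(3.0)            # plain Python numbers count as scalars
assert not isscalar([3.0])
assert isscalar(Tensor(1.0))    # zero-dim tensors already carry the flag

b = Tensor([1.0])
assert not isscalar(b)
setscalar(b)                    # force the flag on a 1-element tensor
assert isscalar(b)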
def astensor1d(x, *reference, dtype=None, device=None):
    """
    Convert something to 1D tensor. Support following types

@@ -195,8 +213,8 @@ def astensor1d(x, *reference, dtype=None, device=None):
    except AttributeError:
        pass
    else:
        if ndim != 1:
            raise ValueError("ndim != 1: %d" % ndim)
        if ndim != 0 and ndim != 1:
            raise ValueError("ndim must be 0 or 1, got: %d" % ndim)
        if not isinstance(x, (TensorBase, TensorWrapperBase)):
            (x,) = Const(x, dtype=dtype, device=device)(*reference)
        return x
@@ -216,7 +234,11 @@ def astensor1d(x, *reference, dtype=None, device=None):
def _expand_int(s, i):
    if isinstance(i, (TensorBase, TensorWrapperBase)):
        s += list(i.numpy())
        i_np = i.numpy()
        if i_np.ndim == 0:
            s.append(int(i_np))
        else:
            s += list(i_np)
        return
    if isinstance(i, Iterable):
        for ii in i:

@@ -63,8 +63,12 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
    """
    op = ParamPackSplit()
    op.offsets = offsets
    op.shapes = shapes
    return apply(op, inp)
    op.shapes = [s or (1,) for s in shapes]
    outputs = apply(op, inp)
    for s, x in zip(shapes, outputs):
        if not s:
            x._isscalar = True
    return outputs


def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
@@ -13,6 +13,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor
@@ -105,7 +106,14 @@ def _elwise(*args, mode):
    args = utils.convert_inputs(*args)
    if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
        args = tuple(map(lambda x: x.astype("float32"), args))
    _isscalar = True
    for i in args:
        if not isscalar(i):
            _isscalar = False
            break
    (result,) = apply(op, *args)
    if _isscalar:
        setscalar(result)
    return result
@@ -63,7 +63,7 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    .. testoutput::

        [2.75]
        2.75
    """
    diff = pred - label

@@ -115,7 +115,7 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
    .. testoutput::

        [9.75]
        9.75
    """
    diff = pred - label

@@ -170,7 +170,7 @@ def cross_entropy(
    .. testoutput::

        [0.6931]
        0.6931
    """
    n0 = pred.ndim

@@ -226,7 +226,7 @@ def binary_cross_entropy(
    .. testoutput::

        [0.6931]
        0.6931
    """
    if not with_logits:

@@ -265,7 +265,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
    .. testoutput::

        [1.5]
        1.5
    """
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
@@ -155,7 +155,7 @@ def sum(
    .. testoutput::

        [21]
        21
    """
    return inp.sum(axis=axis, keepdims=keepdims)

@@ -189,7 +189,7 @@ def prod(
    .. testoutput::

        [720]
        720
    """
    return inp.prod(axis=axis, keepdims=keepdims)

@@ -226,7 +226,7 @@ def mean(
    .. testoutput::

        [3.5]
        3.5
    """
    return inp.mean(axis=axis, keepdims=keepdims)

@@ -263,7 +263,7 @@ def var(
    .. testoutput::

        [2.9167]
        2.9167
    """
    if axis is None:
        m = mean(inp, axis=axis, keepdims=False)

@@ -340,7 +340,7 @@ def min(
    .. testoutput::

        [1]
        1
    """
    return inp.min(axis=axis, keepdims=keepdims)

@@ -377,7 +377,7 @@ def max(
    .. testoutput::

        [6]
        6
    """
    return inp.max(axis=axis, keepdims=keepdims)

@@ -412,7 +412,7 @@ def norm(
    .. testoutput::

        [4.3589]
        4.3589
    """
    if axis is None:

@@ -460,7 +460,7 @@ def argmin(
    .. testoutput::

        [0]
        0
    """
    if isinstance(axis, collections.abc.Iterable):

@@ -519,7 +519,7 @@ def argmax(
    .. testoutput::

        [5]
        5
    """
    if isinstance(axis, collections.abc.Iterable):
@@ -111,6 +111,8 @@ def full(shape, value, dtype="float32", device=None):
    (x,) = Const(value, dtype=dtype, device=device)(
        Tensor(value, dtype=dtype, device=device)
    )
    if len(shape) == 0:  # scalar
        return x
    return broadcast_to(x, shape)
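full() now returns the constant directly for an empty shape instead of broadcasting to it, so it can produce scalars. A short sketch of the expected behaviour (assuming the public megengine.functional.full):

import megengine.functional as F

s = F.full((), 2.0)       # empty shape -> zero-dim result
assert s.ndim == 0

m = F.full((2, 3), 2.0)   # non-empty shapes broadcast as before
assert m.ndim == 2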
@@ -53,7 +53,7 @@ def topk_accuracy(
    .. testoutput::

        [0.] [0.375]
        0.0 0.375
    """
    if isinstance(topk, int):
        topk = (topk,)
@@ -168,8 +168,6 @@ class trace:
        self._output_bindings = None
        self._output_names = None
        set_symbolic_shape(self._symbolic_shape)

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()

@@ -368,6 +366,7 @@ class trace:
        interrupted = False

        def do_enter():
            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
            self._set_active(True)
            if self._untraced:
                self._init_trace(self._symbolic)

@@ -423,6 +422,8 @@ class trace:
            apply.disable(apply_compiled_mode)
            apply.disable(apply_const_compiled_mode)
            self._set_active(False)
            # Restore the global symbolic-shape setting
            set_symbolic_shape(self._save_symbolic_shape)

        def do_exit():
            if not self._untraced and self._pc != len(self._seq):
@@ -498,7 +499,7 @@ class trace:
                opnode = info.data_setter = G.InputNode(
                    device=info.device,
                    dtype=info.dtype,
                    shape=info.shape,
                    shape=info.shape or (1,),
                    graph=graph,
                    use_static_shape=_input_node_use_static_shape(),
                )

@@ -544,7 +545,7 @@ class trace:
                    *links,
                    device=info.device,
                    dtype=info.dtype,
                    shape=info.shape,
                    shape=info.shape or (1,),
                    graph=graph,
                    use_static_shape=_input_node_use_static_shape(),
                )

@@ -719,13 +720,13 @@ class trace:
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=dumped_device,
                shape=info.shape,
                shape=info.shape or (1,),
                name=arg_names[i] if arg_names else None,
            )

        for k, h in self._kwarg_bindings.items():
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype, device=dumped_device, shape=info.shape, name=k
                dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
            )

        for op, ihandles, ohandles in self._seq:
@@ -919,6 +920,7 @@ class CompiledTensorProxy(RawTensor):
    def __init__(self, handle):
        self.__handle = handle
        self._isscalar = False
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None

@@ -934,6 +936,8 @@ class CompiledTensorProxy(RawTensor):
    @property
    def shape(self):
        if self._isscalar:
            return ()
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape

@@ -951,6 +955,8 @@ class CompiledTensorProxy(RawTensor):
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor is not read in trace")
        if self._isscalar:
            self.__value = self.__value.squeeze()
        return self.__value

    def _dev_tensor(self):
@@ -970,9 +976,10 @@ class CompiledTensorProxy(RawTensor):
class LazyEvalTensor(RawTensor):
    def __init__(self, varnode):
        super(LazyEvalTensor, self).__init__()
    def __init__(self, varnode, isscalar=False):
        super().__init__()
        self.__varnode = varnode
        self._isscalar = isscalar

    @property
    def dtype(self):

@@ -984,10 +991,15 @@ class LazyEvalTensor(RawTensor):
    @property
    def shape(self):
        if self._isscalar:
            return ()
        return self.__varnode.shape

    def numpy(self):
        return self.__varnode.value
        ret = self.__varnode.value
        if self._isscalar:
            ret = ret.squeeze()
        return ret

    def _dev_tensor(self):
        raise RuntimeError("cannot access data during symbolic tracing")
@@ -1041,10 +1053,12 @@ class TracedLazyTensor(TraceMixin, LazyEvalTensor):
def assign_raw_tensor(lhs, rhs):
    handle = rhs._handle
    # Keep isscalar of lhs
    isscalar = lhs._isscalar
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)
    lhs.__init__(handle, isscalar=isscalar)


# this hook turns RawTensor into LazyEvalTensor
@@ -1060,7 +1074,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
            data_setter = G.InputNode(
                device=x.device,
                dtype=x.dtype,
                shape=x.shape,
                shape=x.shape or (1,),
                graph=graph,
                use_static_shape=True,
            )

@@ -1091,7 +1105,9 @@ apply.disable(apply_symbolic_mode)
@apply.register()
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
    ret = LazyEvalTensor(
        graph.make_const(op.value, dtype=op.dtype, device=op.device), isscalar=True
    )
    active_trace._lazy_eval_tensors.add(ret)
    return (ret,)
@@ -46,9 +46,9 @@ class Observer(Module):
    def get_dtype(self):
        q_dict = self.get_qparams()
        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
        numpy_zero_point = (
            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
        )
        return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
@@ -18,7 +18,7 @@ from megengine.module import Module
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.0, dtype=np.float32)
        self.a = Parameter([1.0], dtype=np.float32)

    def forward(self, x, y):
        x = x[y] * self.a

@@ -28,7 +28,7 @@ class Simple2(Module):
class Simple2(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.0, dtype=np.float32)
        self.a = Parameter([1.0], dtype=np.float32)

    def forward(self, x):
        x = x[1, ..., :, 0:4:2, 0:2] * self.a

@@ -18,7 +18,7 @@ from megengine.module import Module
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.0, dtype=np.float32)
        self.a = Parameter([1.0], dtype=np.float32)

    def forward(self, x):
        x = x[:, 0] * self.a

@@ -18,8 +18,8 @@ from megengine.module import Module
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.0, dtype=np.float32)
        self.b = Parameter(1.0, dtype=np.float32)
        self.a = Parameter([1.0], dtype=np.float32)
        self.b = Parameter([1.0], dtype=np.float32)

    def forward(self, x):
        x = x * self.a

@@ -21,7 +21,7 @@ from megengine.module import Module
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)
        self.a = Parameter([1.23], dtype=np.float32)

    def forward(self, x):
        x = x * self.a

@@ -18,7 +18,7 @@ from megengine.optimizer import SGD, MultiStepLR
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)
        self.a = Parameter([1.23], dtype=np.float32)

    def forward(self, x):
        x = x * self.a

@@ -32,7 +32,7 @@ class MLP(Module):
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)
        self.a = Parameter([1.23], dtype=np.float32)

    def forward(self, x):
        x = x * self.a

@@ -11,7 +11,7 @@ from megengine.module import Module
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)
        self.a = Parameter([1.23], dtype=np.float32)

    def forward(self, x):
        x = x * self.a

@@ -19,7 +19,7 @@ from megengine.module import Module
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)
        self.a = Parameter([1.23], dtype=np.float32)

    def forward(self, x):
        x = x * self.a
@@ -107,7 +107,7 @@ def test_xornet_trace_dump():
        if step % 50 == 0:
            minibatch = next(val_dataset)
            _, loss = val_fun(data, label)
            loss = loss.numpy()[0]
            loss = loss.numpy()
            val_loss.append((step, loss))
            print("Step: {} loss={}".format(step, loss))
        opt.step()
@@ -449,7 +449,7 @@ def test_advance_indexing_high_level():
    y = np.array([1, 2])
    yy = Tensor(y)
    np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) # FIXME
    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
    np.testing.assert_equal(x[:, y], xx[:, y].numpy())
    np.testing.assert_equal(x[:, y], xx[:, yy].numpy())

@@ -469,10 +469,9 @@ def test_advance_indexing_high_level():
    y = np.array([1])
    yy = Tensor(y)
    np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) # FIXME
    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
    np.testing.assert_equal(x[:, y], xx[:, y].numpy())
    # XXX: no way to tell whether yy is scalar or ndim=1 array
    np.testing.assert_equal(x[:, y], xx[:, yy].numpy())

    x = np.arange(9).reshape(3, 3).astype("int32")
@@ -21,6 +21,7 @@ from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
from megengine.random import normal, uniform

@@ -263,20 +264,21 @@ def test_optimize_for_inference_broadcast():
def test_trace_cvt_bool():
    set_symbolic_shape(True)
    x = tensor([0], dtype=np.int32)

    @trace(symbolic=True)
    def f(x):
        return x.shape[0] == 0
        a = x.shape
        b = a[0]
        assert isscalar(b)
        return b == 0

    for i in range(3):
        np.testing.assert_equal(f(x).numpy()[0], False)
        np.testing.assert_equal(f(x).numpy(), False)


def test_trace_reshape():
    for symbolic in [False, True]:
        set_symbolic_shape(True)
        x1 = tensor(np.random.randn(2, 10, 10))
        x2 = tensor(np.random.randn(4, 10, 10))
        x3 = tensor(np.random.randn(8, 10, 10))

@@ -359,7 +361,6 @@ def test_raise_on_trace():
def test_trace_broadcast():
    for symbolic in [False, True]:
        set_symbolic_shape(True)
        x1 = tensor(np.random.randn(3, 1, 1))
        x2 = tensor(np.random.randn(1, 4, 1))
        x3 = tensor(np.random.randn(1, 1, 5))

@@ -397,7 +398,6 @@ def test_trace_nms():
def test_trace_valid_broadcast():
    set_symbolic_shape(True)
    x1 = tensor(np.random.randn(1, 1))
    x2 = tensor(np.random.randn(1, 2))
    shape = (tensor([2]), tensor([2]))
@@ -0,0 +1,52 @@
import numpy as np

import megengine.functional as F
from megengine import Tensor
from megengine.core._trace_option import use_symbolic_shape


def test_zero_dim():
    a = Tensor(1)
    a_np = np.array(1, dtype=np.int32)
    np.testing.assert_equal(a, a_np)
    if use_symbolic_shape():
        np.testing.assert_equal(a.shape, np.array(a_np.shape))
    else:
        np.testing.assert_equal(a.shape, a_np.shape)


def test_sum():
    a = Tensor([1, 2])
    a = a.reshape((1, 2))
    assert a.sum().ndim == 0
    assert a.sum(axis=1).ndim == 1


def test_max():
    a = Tensor([1, 2])
    a = a.reshape((1, 2))
    assert a.max().ndim == 0
    assert a.max(axis=1).ndim == 1


def test_reshape():
    a = Tensor(1)
    a = a.reshape((1, 1))


def test_squeeze():
    a = Tensor(1)
    a = a.reshape((1, 1))
    assert F.squeeze(a).ndim == 0


def test_elementwise():
    a = Tensor(1.0)
    assert F.exp(a).ndim == 0
    assert (a + a).ndim == 0
    assert (a + 1).ndim == 0


def test_astype():
    a = Tensor(1.0)
    assert a.astype("int32").ndim == 0