
feat(mge/imperative): simulates scalar

GitOrigin-RevId: e81630e256
tags/v1.1.0
Megvii Engine Team, 5 years ago
parent commit: 638ab52fdc
26 changed files with 214 additions and 95 deletions
  1. imperative/python/megengine/core/_trace_option.py (+2, -0)
  2. imperative/python/megengine/core/tensor/indexing.py (+13, -16)
  3. imperative/python/megengine/core/tensor/megbrain_graph.py (+2, -1)
  4. imperative/python/megengine/core/tensor/raw_tensor/__init__.py (+9, -3)
  5. imperative/python/megengine/core/tensor/tensor_wrapper.py (+27, -17)
  6. imperative/python/megengine/core/tensor/utils.py (+29, -7)
  7. imperative/python/megengine/distributed/helper.py (+6, -2)
  8. imperative/python/megengine/functional/elemwise.py (+8, -0)
  9. imperative/python/megengine/functional/loss.py (+5, -5)
  10. imperative/python/megengine/functional/math.py (+9, -9)
  11. imperative/python/megengine/functional/tensor.py (+2, -0)
  12. imperative/python/megengine/functional/utils.py (+1, -1)
  13. imperative/python/megengine/jit/tracing.py (+28, -12)
  14. imperative/python/megengine/quantization/observer.py (+2, -2)
  15. imperative/python/test/integration/test_advance_indexing.py (+2, -2)
  16. imperative/python/test/integration/test_ai.py (+1, -1)
  17. imperative/python/test/integration/test_detach.py (+2, -2)
  18. imperative/python/test/integration/test_hello_world.py (+1, -1)
  19. imperative/python/test/integration/test_lr_scheduler.py (+1, -1)
  20. imperative/python/test/integration/test_optimizer.py (+1, -1)
  21. imperative/python/test/integration/test_save_load.py (+1, -1)
  22. imperative/python/test/integration/test_sgd_momentum.py (+1, -1)
  23. imperative/python/test/integration/test_trace_dump.py (+1, -1)
  24. imperative/python/test/unit/core/test_indexing_op.py (+2, -3)
  25. imperative/python/test/unit/test_tracing.py (+6, -6)
  26. imperative/python/test/unit/test_zero_dim_tensor.py (+52, -0)

imperative/python/megengine/core/_trace_option.py (+2, -0)

@@ -26,4 +26,6 @@ def set_symbolic_shape(option: bool):
     """ Sets whether tensor.shape returns a tensor instead of a tuple
     """
     global _use_symbolic_shape
+    _org = _use_symbolic_shape
     _use_symbolic_shape = option
+    return _org
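With this change set_symbolic_shape returns the previous value, so a caller can flip the option temporarily and put it back, which the tracer change below relies on. A minimal usage sketch under that assumption:

    from megengine.core._trace_option import set_symbolic_shape

    saved = set_symbolic_shape(True)   # switch modes, remembering the old one
    try:
        pass  # work that needs tensor.shape to be a symbolic tensor
    finally:
        set_symbolic_shape(saved)      # restore the caller's mode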

imperative/python/megengine/core/tensor/indexing.py (+13, -16)

@@ -14,7 +14,7 @@ from .._trace_option import use_symbolic_shape
 from ..ops import builtin
 from ..ops.special import Const
 from .core import TensorBase, TensorWrapperBase, apply
-from .utils import astensor1d, make_shape_tuple
+from .utils import astensor1d, isscalar, make_shape_tuple


 def remove_ellipsis(tensor, tuple_val):
@@ -89,9 +89,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
     if not isinstance(tuple_val, tuple):
         tuple_val = (tuple_val,)
     ndim_indexed = 0
+    ndim_indexed_scalar = 0
     for i in tuple_val:
         if not i is Ellipsis:
             ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim
+            if isscalar(i):
+                ndim_indexed_scalar += 1
+
     if ndim_indexed > inp.ndim:
         raise IndexError(
             "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
@@ -103,15 +107,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
         use_subtensor = True
     inp, tuple_val = check_bool_index(inp, tuple_val)

-    def is_scalar(d):
-        if isinstance(i, int):
-            return True
-        if type(d).__module__ == np.__name__:
-            return np.isscalar(d)
-        # if isinstance(d, (TensorBase, TensorWrapperBase)):
-        #     return d.shape == (1,)
-        return False
-
     new_axes = []
     tensors = []
     items = []
@@ -134,7 +129,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
             continue

         if (
-            not is_scalar(i)
+            not isscalar(i)
             and not i is np.newaxis
             and not i is Ellipsis
             and not isinstance(i, slice)
@@ -191,7 +186,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
         items.append(item)
     if new_axes:
         raise IndexError("newaxis is not allowed here")
-    return inp, tensors, items, use_subtensor
+    return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim


 def try_condtake(tensor, index):
@@ -217,11 +212,11 @@ def getitem(tensor, index):
     try_result = try_condtake(tensor, index)
     if len(try_result) == 2:
         return try_result[0]
-    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
+    tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
     for v in tensors:
         if isinstance(v.shape, v.__class__):
             break
-        if v.shape[0] == 0:
+        if len(v.shape) > 0 and v.shape[0] == 0:
             (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
                 tensor
             )
@@ -231,6 +226,8 @@ def getitem(tensor, index):
     else:
         op = builtin.IndexingMultiAxisVec(items=items)
     (result,) = apply(op, tensor, *tensors)
+    if ret_scalar:
+        result.__wrapped__._data._isscalar = True
     return result


@@ -245,9 +242,9 @@ def setitem(tensor, index, value):
     if not isinstance(value, (TensorBase, TensorWrapperBase)):
         op = Const(value, dtype=tensor.dtype, device=tensor.device)
         (value,) = op(tensor)
-    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
+    tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
     for v in tensors:
-        if v.shape[0] == 0:
+        if len(v.shape) > 0 and v.shape[0] == 0:
             return tensor
     if use_subtensor:
         op = builtin.Subtensor(items=items)
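The new fifth return value of unpack_getitem is true when every indexed axis is consumed by a scalar index, and getitem then marks its result as 0-dim. A hedged illustration of the intended behavior at this revision:

    import numpy as np
    from megengine import Tensor

    x = Tensor(np.arange(6).reshape(2, 3))
    assert x[0, 1].ndim == 0  # every axis consumed by a scalar index
    assert x[0].ndim == 1     # a remaining axis keeps the result non-scalar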


imperative/python/megengine/core/tensor/megbrain_graph.py (+2, -1)

@@ -102,8 +102,9 @@ class Graph(_imperative_rt.ComputingGraph):


 class VarNode(TensorBase):
-    def __init__(self, node: _imperative_rt.VarNode):
+    def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
         self._node = node
+        self._isscalar = isscalar
         if hasattr(self.graph, "_var_cache"):
             self.graph._var_cache[node] = self




imperative/python/megengine/core/tensor/raw_tensor/__init__.py (+9, -3)

@@ -33,8 +33,9 @@ class RawTensor(TensorBase):
     _del_cb = None
     _handle = None

-    def __init__(self, handle=None):
+    def __init__(self, handle=None, isscalar=False):
         self._handle = handle
+        self._isscalar = isscalar
         if handle is not None:
             if self._init_cb:
                 self._init_cb()
@@ -49,10 +50,15 @@ class RawTensor(TensorBase):

     @property
     def shape(self):
+        if self._isscalar:
+            return ()
         return get_shape(self._handle)

     def numpy(self):
-        return get_value(self._handle)
+        ret = get_value(self._handle)
+        if self._isscalar:
+            ret = ret.squeeze()
+        return ret

     def _dev_tensor(self):
         return _get_dev_tensor(self._handle)
@@ -102,7 +108,7 @@ def _(array: np.ndarray, dtype=None, device=None):
     device = None if device is None else as_device(device).to_c()
     if 0 in array.strides:
         array = array.squeeze().reshape(array.shape)
-    return RawTensor(put(array, dtype=dtype, device=device))
+    return RawTensor(put(array, dtype=dtype, device=device), isscalar=(array.ndim == 0))


 @as_raw_tensor.register(RawTensor)
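This is the core of the simulated scalar scheme: device storage stays rank-1 or higher, while an _isscalar flag makes shape report () and numpy() squeeze the payload. A self-contained sketch of the idea in plain numpy (class name hypothetical, not from the commit):

    import numpy as np

    class SimulatedScalar:
        # stores at least one element but presents itself as 0-dim
        def __init__(self, data, isscalar=False):
            self._data = np.atleast_1d(np.asarray(data))
            self._isscalar = isscalar

        @property
        def shape(self):
            return () if self._isscalar else self._data.shape

        def numpy(self):
            return self._data.squeeze() if self._isscalar else self._data

    x = SimulatedScalar(1.0, isscalar=True)
    assert x.shape == () and x.numpy().ndim == 0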


imperative/python/megengine/core/tensor/tensor_wrapper.py (+27, -17)

@@ -21,7 +21,9 @@ from .indexing import getitem as _getitem
 from .indexing import setitem as _setitem
 from .raw_tensor import RawTensor, as_raw_tensor
 from .tensor import Tensor
+from .utils import isscalar
 from .utils import make_shape_tuple as _make_shape_tuple
+from .utils import setscalar

 _ElwMod = Elemwise.Mode

@@ -39,6 +41,13 @@ def _elwise(*args, mode):
     )
     args = utils.convert_inputs(*args)
     (result,) = apply(op, *args)
+    _isscalar = True
+    for i in args:
+        if isscalar(i) == False:
+            _isscalar = False
+            break
+    if _isscalar:
+        setscalar(result)
     return result


@@ -153,6 +162,8 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
     param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
     op = builtin.AxisAddRemove(param=param)
     (result,) = apply(op, inp)
+    if len(axis) == inp.ndim:
+        setscalar(result)
     return result


@@ -189,6 +200,8 @@ def _reduce(mode):
         if self.dtype == np.bool_:
             if mode in ["MIN", "MAX"]:
                 result = result.astype("bool")
+        if axis is None or self.ndim == 1:
+            setscalar(result)
         return result

     return f
@@ -321,9 +334,7 @@ class ArrayMethodMixin(abc.ABC):
     __complex__ = lambda self: complex(self.item())

     def __len__(self):
-        shape = self.shape
-        if use_symbolic_shape():
-            shape = shape.numpy()
+        shape = self.__wrapped__.shape
         if shape:
             return int(shape[0])
         raise TypeError("ndim is 0")
@@ -344,18 +355,17 @@ class ArrayMethodMixin(abc.ABC):

     @property
     def ndim(self):
-        shape = self.shape
-        if isinstance(shape, self.__class__):
-            # XXX: assume ndim is not changed during trace
-            ndim = shape.__wrapped__.shape[0]
-            return ndim
+        shape = self.__wrapped__.shape
+        if shape is None:
+            raise ValueError("unkown ndim")
         return len(shape)

     @property
     def size(self):
-        if use_symbolic_shape():
-            return self.shape.prod()
-        return np.prod(self.shape).item()
+        shape = self.shape
+        if shape.__class__ is tuple:
+            return np.prod(self.shape).item()
+        return shape.prod()

     @property
     def T(self):
@@ -416,8 +426,8 @@ class ArrayMethodMixin(abc.ABC):

         .. testoutput::

-            [2]
-            [10.]
+            2
+            10.

         """
         return _reduce("SUM")(self, axis, keepdims)
@@ -444,10 +454,10 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):

     @property
     def shape(self):
-        if use_symbolic_shape():
-            return apply(GetVarShape(), self)[0]
-        else:
-            return self.__wrapped__.shape
+        shape = self.__wrapped__.shape
+        if shape == () or not use_symbolic_shape():
+            return shape
+        return apply(GetVarShape(), self)[0]

     @property
     def device(self):
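Elementwise results are flagged scalar only when every input is scalar, and reductions become scalar when they collapse all axes. Mirroring the assertions in the new test_zero_dim_tensor.py:

    from megengine import Tensor

    a = Tensor(1.0)          # 0-dim
    b = Tensor([1.0, 2.0])   # 1-dim
    assert (a + a).ndim == 0  # all inputs scalar, result scalar
    assert (a + 1).ndim == 0  # Python numbers count as scalars
    assert (a + b).ndim == 1  # any non-scalar input promotes the result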


imperative/python/megengine/core/tensor/utils.py (+29, -7)

@@ -133,7 +133,9 @@ def concatenate(inputs, axis=0, *, device=None):
 def astype(x, dtype):
     dtype = np.dtype(dtype)
     if not is_equal(x.dtype, dtype):
+        isscalar = x.__wrapped__._data._isscalar
         (x,) = apply(builtin.TypeCvt(param=dtype), x)
+        x.__wrapped__._data._isscalar = isscalar
     return x


@@ -176,13 +178,29 @@ def result_type(*args):


 def isscalar(x):
-    try:
-        return x.ndim == 0
-    except:
-        pass
+    if isinstance(x, TensorWrapperBase):
+        x = x.__wrapped__
+
+    if hasattr(x, "_isscalar"):
+        return x._isscalar
+    if isinstance(x, TensorBase):
+        return x._data._isscalar
+
     return np.isscalar(x)


+def setscalar(x):
+    if isinstance(x, TensorWrapperBase):
+        x = x.__wrapped__
+
+    if hasattr(x, "_isscalar"):
+        x._isscalar = True
+    elif isinstance(x, TensorBase):
+        x._data._isscalar = True
+    else:
+        raise NotImplementedError("Unsupport type {}".format(type(x)))
+
+
 def astensor1d(x, *reference, dtype=None, device=None):
     """
     Convert something to 1D tensor. Support following types
@@ -195,8 +213,8 @@ def astensor1d(x, *reference, dtype=None, device=None):
     except AttributeError:
         pass
     else:
-        if ndim != 1:
-            raise ValueError("ndim != 1: %d" % ndim)
+        if ndim != 0 and ndim != 1:
+            raise ValueError("ndim != 1 or 0, get : %d" % ndim)
     if not isinstance(x, (TensorBase, TensorWrapperBase)):
         (x,) = Const(x, dtype=dtype, device=device)(*reference)
     return x
@@ -216,7 +234,11 @@ def astensor1d(x, *reference, dtype=None, device=None):

 def _expand_int(s, i):
     if isinstance(i, (TensorBase, TensorWrapperBase)):
-        s += list(i.numpy())
+        i_np = i.numpy()
+        if i_np.ndim == 0:
+            s.append(int(i_np))
+        else:
+            s += list(i_np)
         return
     if isinstance(i, Iterable):
         for ii in i:
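isscalar and setscalar are the read and write ends of the flag, and astype carries it across a dtype conversion because TypeCvt yields a fresh tensor. A hedged usage sketch:

    import numpy as np
    from megengine import Tensor
    from megengine.core.tensor.utils import isscalar

    t = Tensor(3.0)
    assert isscalar(t)                  # 0-dim tensor carries the flag
    assert isscalar(5)                  # plain numbers fall back to np.isscalar
    assert isscalar(t.astype("int32"))  # the flag survives astype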


imperative/python/megengine/distributed/helper.py (+6, -2)

@@ -63,8 +63,12 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
     """
     op = ParamPackSplit()
     op.offsets = offsets
-    op.shapes = shapes
-    return apply(op, inp)
+    op.shapes = [s or (1,) for s in shapes]
+    outputs = apply(op, inp)
+    for s, x in zip(shapes, outputs):
+        if not s:
+            x._isscalar = True
+    return outputs


 def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
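An empty shape tuple is falsy, so s or (1,) pads scalar entries to shape (1,) for the underlying op, and the flag is restored on the matching outputs. The padding idiom in isolation:

    shapes = [(), (2, 3), ()]
    padded = [s or (1,) for s in shapes]  # () is falsy, so it becomes (1,)
    assert padded == [(1,), (2, 3), (1,)]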


imperative/python/megengine/functional/elemwise.py (+8, -0)

@@ -13,6 +13,7 @@ from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
 from ..core.tensor import megbrain_graph, utils
 from ..core.tensor.core import apply
+from ..core.tensor.utils import isscalar, setscalar
 from ..device import get_default_device
 from ..jit.tracing import is_tracing
 from ..tensor import Tensor
@@ -105,7 +106,14 @@ def _elwise(*args, mode):
     args = utils.convert_inputs(*args)
     if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
         args = tuple(map(lambda x: x.astype("float32"), args))
+    _isscalar = True
+    for i in args:
+        if isscalar(i) == False:
+            _isscalar = False
+            break
     (result,) = apply(op, *args)
+    if _isscalar:
+        setscalar(result)
     return result
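The scalar check runs on the inputs after the float32 casts, which is safe because astype preserves the flag (see the utils.py change above). As in the new unit test:

    import megengine.functional as F
    from megengine import Tensor

    a = Tensor(1.0)
    assert F.exp(a).ndim == 0  # cast to float32, flag survives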






imperative/python/megengine/functional/loss.py (+5, -5)

@@ -63,7 +63,7 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:

     .. testoutput::

-        [2.75]
+        2.75

     """
     diff = pred - label
@@ -115,7 +115,7 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:

     .. testoutput::

-        [9.75]
+        9.75

     """
     diff = pred - label
@@ -170,7 +170,7 @@ def cross_entropy(

     .. testoutput::

-        [0.6931]
+        0.6931

     """
     n0 = pred.ndim
@@ -226,7 +226,7 @@ def binary_cross_entropy(

     .. testoutput::

-        [0.6931]
+        0.6931

     """
     if not with_logits:
@@ -265,7 +265,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:

     .. testoutput::

-        [1.5]
+        1.5

     """
     assert norm in ["L1", "L2"], "norm must be L1 or L2"
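Losses are full reductions, so they now come back as 0-dim tensors and the doctest outputs lose their brackets. A hedged example, assuming l1_loss keeps its documented mean reduction:

    import numpy as np
    from megengine import Tensor
    from megengine.functional.loss import l1_loss

    pred = Tensor(np.array([1, 2, 3, 4], dtype=np.float32))
    label = Tensor(np.array([2, 4, 6, 8], dtype=np.float32))
    loss = l1_loss(pred, label)
    assert loss.ndim == 0  # previously shape (1,)
    np.testing.assert_allclose(loss.numpy(), 2.5)  # mean |diff| = (1+2+3+4)/4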


imperative/python/megengine/functional/math.py (+9, -9)

@@ -155,7 +155,7 @@ def sum(

     .. testoutput::

-        [21]
+        21

     """
     return inp.sum(axis=axis, keepdims=keepdims)
@@ -189,7 +189,7 @@ def prod(

     .. testoutput::

-        [720]
+        720

     """
     return inp.prod(axis=axis, keepdims=keepdims)
@@ -226,7 +226,7 @@ def mean(

     .. testoutput::

-        [3.5]
+        3.5

     """
     return inp.mean(axis=axis, keepdims=keepdims)
@@ -263,7 +263,7 @@ def var(

     .. testoutput::

-        [2.9167]
+        2.9167
     """
     if axis is None:
         m = mean(inp, axis=axis, keepdims=False)
@@ -340,7 +340,7 @@ def min(

     .. testoutput::

-        [1]
+        1

     """
     return inp.min(axis=axis, keepdims=keepdims)
@@ -377,7 +377,7 @@ def max(

     .. testoutput::

-        [6]
+        6

     """
     return inp.max(axis=axis, keepdims=keepdims)
@@ -412,7 +412,7 @@ def norm(

     .. testoutput::

-        [4.3589]
+        4.3589

     """
     if axis is None:
@@ -460,7 +460,7 @@ def argmin(

     .. testoutput::

-        [0]
+        0

     """
     if isinstance(axis, collections.abc.Iterable):
@@ -519,7 +519,7 @@ def argmax(

     .. testoutput::

-        [5]
+        5

     """
     if isinstance(axis, collections.abc.Iterable):
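These follow the rule in _reduce above: a full reduction (axis=None) yields a 0-dim tensor, while an axis-wise reduction keeps an array. Mirroring test_zero_dim_tensor.py:

    from megengine import Tensor

    a = Tensor([1, 2]).reshape((1, 2))
    assert a.sum().ndim == 0        # full reduction, scalar result
    assert a.sum(axis=1).ndim == 1  # per-axis reduction stays 1-dim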


imperative/python/megengine/functional/tensor.py (+2, -0)

@@ -111,6 +111,8 @@ def full(shape, value, dtype="float32", device=None):
     (x,) = Const(value, dtype=dtype, device=device)(
         Tensor(value, dtype=dtype, device=device)
     )
+    if len(shape) == 0:  # scalar
+        return x
     return broadcast_to(x, shape)
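full() with an empty shape now short-circuits and returns the constant itself, which is already flagged scalar. A hedged example, assuming full accepts an empty tuple at this revision:

    import megengine.functional as F

    assert F.full((), 2.0).ndim == 0            # scalar constant, no broadcast
    assert F.full((2, 3), 2.0).shape == (2, 3)  # non-empty shapes unchanged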






imperative/python/megengine/functional/utils.py (+1, -1)

@@ -53,7 +53,7 @@ def topk_accuracy(

     .. testoutput::

-        [0.] [0.375]
+        0.0 0.375
     """
     if isinstance(topk, int):
         topk = (topk,)


imperative/python/megengine/jit/tracing.py (+28, -12)

@@ -168,8 +168,6 @@ class trace:
         self._output_bindings = None
         self._output_names = None

-        set_symbolic_shape(self._symbolic_shape)
-
     def _new_handle(self):
         handle = len(self._tinfo)
         info = TensorInfo()
@@ -368,6 +366,7 @@ class trace:
         interrupted = False

         def do_enter():
+            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
             self._set_active(True)
             if self._untraced:
                 self._init_trace(self._symbolic)
@@ -423,6 +422,8 @@ class trace:
             apply.disable(apply_compiled_mode)
             apply.disable(apply_const_compiled_mode)
             self._set_active(False)
+            # Restore global variable
+            set_symbolic_shape(self._save_symbolic_shape)

         def do_exit():
             if not self._untraced and self._pc != len(self._seq):
@@ -498,7 +499,7 @@ class trace:
                 opnode = info.data_setter = G.InputNode(
                     device=info.device,
                     dtype=info.dtype,
-                    shape=info.shape,
+                    shape=info.shape or (1,),
                     graph=graph,
                     use_static_shape=_input_node_use_static_shape(),
                 )
@@ -544,7 +545,7 @@ class trace:
                     *links,
                     device=info.device,
                     dtype=info.dtype,
-                    shape=info.shape,
+                    shape=info.shape or (1,),
                     graph=graph,
                     use_static_shape=_input_node_use_static_shape(),
                 )
@@ -719,13 +720,13 @@ class trace:
             h2v[h] = graph.make_h2d(
                 dtype=info.dtype,
                 device=dumped_device,
-                shape=info.shape,
+                shape=info.shape or (1,),
                 name=arg_names[i] if arg_names else None,
             )
         for k, h in self._kwarg_bindings.items():
             info = self._tinfo[h]
             h2v[h] = graph.make_h2d(
-                dtype=info.dtype, device=dumped_device, shape=info.shape, name=k
+                dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
             )

         for op, ihandles, ohandles in self._seq:
@@ -919,6 +920,7 @@ class CompiledTensorProxy(RawTensor):

     def __init__(self, handle):
         self.__handle = handle
+        self._isscalar = False
         self.__info = active_trace._tinfo[handle]
         self.__shape = None
         self.__data = None
@@ -934,6 +936,8 @@ class CompiledTensorProxy(RawTensor):

     @property
     def shape(self):
+        if self._isscalar:
+            return ()
         if self.__shape is None:
             if self.__info.shape_read:
                 self.__shape = self.__info.shape_reader.get_value().shape
@@ -951,6 +955,8 @@ class CompiledTensorProxy(RawTensor):
             self.__value = self._dev_tensor().numpy()
         else:
             raise TraceMismatchError("value of this tensor is not read in trace")
+        if self._isscalar:
+            self.__value = self.__value.squeeze()
         return self.__value

     def _dev_tensor(self):
@@ -970,9 +976,10 @@ class CompiledTensorProxy(RawTensor):


 class LazyEvalTensor(RawTensor):
-    def __init__(self, varnode):
-        super(LazyEvalTensor, self).__init__()
+    def __init__(self, varnode, isscalar=False):
+        super().__init__()
         self.__varnode = varnode
+        self._isscalar = isscalar

     @property
     def dtype(self):
@@ -984,10 +991,15 @@ class LazyEvalTensor(RawTensor):

     @property
     def shape(self):
+        if self._isscalar:
+            return ()
         return self.__varnode.shape

     def numpy(self):
-        return self.__varnode.value
+        ret = self.__varnode.value
+        if self._isscalar:
+            ret = ret.squeeze()
+        return ret

     def _dev_tensor(self):
         raise RuntimeError("cannot access data during symbolic tracing")
@@ -1041,10 +1053,12 @@ class TracedLazyTensor(TraceMixin, LazyEvalTensor):


 def assign_raw_tensor(lhs, rhs):
     handle = rhs._handle
+    # Keep isscalar of lhs
+    isscalar = lhs._isscalar
     rhs.__dict__.clear()
     lhs.__dict__.clear()
     lhs.__class__ = RawTensor
-    lhs.__init__(handle)
+    lhs.__init__(handle, isscalar=isscalar)


 # this hook turns RawTensor into LazyEvalTensor
@@ -1060,7 +1074,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
             data_setter = G.InputNode(
                 device=x.device,
                 dtype=x.dtype,
-                shape=x.shape,
+                shape=x.shape or (1,),
                 graph=graph,
                 use_static_shape=True,
             )
@@ -1091,7 +1105,9 @@ apply.disable(apply_symbolic_mode)
 @apply.register()
 def apply_const_symbolic_mode(op: Const, *args: RawTensor):
     graph = active_trace._lazy_eval_graph
-    ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
+    ret = LazyEvalTensor(
+        graph.make_const(op.value, dtype=op.dtype, device=op.device), isscalar=True
+    )
     active_trace._lazy_eval_tensors.add(ret)
     return (ret,)
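Symbolic-shape mode is now switched on when the trace is entered and the caller's mode restored on exit, instead of being set once in __init__; this is why the tests below drop their manual set_symbolic_shape calls. Graph input nodes pad simulated-scalar shapes with shape or (1,), matching the rank-1 storage. Adapted from the updated test_trace_cvt_bool:

    import numpy as np
    from megengine import tensor
    from megengine.jit import trace

    x = tensor([0], dtype=np.int32)

    @trace(symbolic=True)
    def f(x):
        return x.shape[0] == 0  # x.shape[0] is a 0-dim tensor under the trace

    for _ in range(3):
        np.testing.assert_equal(f(x).numpy(), False)  # no [0] indexing needed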




imperative/python/megengine/quantization/observer.py (+2, -2)

@@ -46,9 +46,9 @@ class Observer(Module):

     def get_dtype(self):
         q_dict = self.get_qparams()
-        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
+        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
         numpy_zero_point = (
-            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
+            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
         )
         return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)




imperative/python/test/integration/test_advance_indexing.py (+2, -2)

@@ -18,7 +18,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)

     def forward(self, x, y):
         x = x[y] * self.a
@@ -28,7 +28,7 @@ class Simple(Module):
 class Simple2(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)

     def forward(self, x):
         x = x[1, ..., :, 0:4:2, 0:2] * self.a
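This and the following integration-test edits all share one pattern: a Parameter built from a bare Python float becomes a 0-dim tensor under this commit, so the tests wrap the initializer in a list to keep the old shape-(1,) behavior (a hedged reading; the commit itself does not state the motivation).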


imperative/python/test/integration/test_ai.py (+1, -1)

@@ -18,7 +18,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)

     def forward(self, x):
         x = x[:, 0] * self.a


imperative/python/test/integration/test_detach.py (+2, -2)

@@ -18,8 +18,8 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
-        self.b = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)
+        self.b = Parameter([1.0], dtype=np.float32)

     def forward(self, x):
         x = x * self.a


imperative/python/test/integration/test_hello_world.py (+1, -1)

@@ -21,7 +21,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a


imperative/python/test/integration/test_lr_scheduler.py (+1, -1)

@@ -18,7 +18,7 @@ from megengine.optimizer import SGD, MultiStepLR
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a


imperative/python/test/integration/test_optimizer.py (+1, -1)

@@ -32,7 +32,7 @@ class MLP(Module):
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a


imperative/python/test/integration/test_save_load.py (+1, -1)

@@ -11,7 +11,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a


imperative/python/test/integration/test_sgd_momentum.py (+1, -1)

@@ -19,7 +19,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)

     def forward(self, x):
         x = x * self.a


imperative/python/test/integration/test_trace_dump.py (+1, -1)

@@ -107,7 +107,7 @@ def test_xornet_trace_dump():
         if step % 50 == 0:
             minibatch = next(val_dataset)
             _, loss = val_fun(data, label)
-            loss = loss.numpy()[0]
+            loss = loss.numpy()
             val_loss.append((step, loss))
             print("Step: {} loss={}".format(step, loss))
         opt.step()


imperative/python/test/unit/core/test_indexing_op.py (+2, -3)

@@ -449,7 +449,7 @@ def test_advance_indexing_high_level():
     y = np.array([1, 2])
     yy = Tensor(y)
     np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
-    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())  # FIXME
+    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
     np.testing.assert_equal(x[:, y], xx[:, y].numpy())
     np.testing.assert_equal(x[:, y], xx[:, yy].numpy())

@@ -469,10 +469,9 @@ def test_advance_indexing_high_level():
     y = np.array([1])
     yy = Tensor(y)
     np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
-    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())  # FIXME
+    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
     np.testing.assert_equal(x[:, y], xx[:, y].numpy())

-    # XXX: no way to tell whether yy is scalar or ndim=1 array
     np.testing.assert_equal(x[:, y], xx[:, yy].numpy())

     x = np.arange(9).reshape(3, 3).astype("int32")


imperative/python/test/unit/test_tracing.py (+6, -6)

@@ -21,6 +21,7 @@ from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.core import apply
 from megengine.core.tensor.raw_tensor import as_raw_tensor
+from megengine.core.tensor.utils import isscalar
 from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
 from megengine.random import normal, uniform
@@ -263,20 +264,21 @@ def test_optimize_for_inference_broadcast():


 def test_trace_cvt_bool():
-    set_symbolic_shape(True)
     x = tensor([0], dtype=np.int32)

     @trace(symbolic=True)
     def f(x):
-        return x.shape[0] == 0
+        a = x.shape
+        b = a[0]
+        assert isscalar(b)
+        return b == 0

     for i in range(3):
-        np.testing.assert_equal(f(x).numpy()[0], False)
+        np.testing.assert_equal(f(x).numpy(), False)


 def test_trace_reshape():
     for symbolic in [False, True]:
-        set_symbolic_shape(True)
         x1 = tensor(np.random.randn(2, 10, 10))
         x2 = tensor(np.random.randn(4, 10, 10))
         x3 = tensor(np.random.randn(8, 10, 10))
@@ -359,7 +361,6 @@ def test_raise_on_trace():

 def test_trace_broadcast():
     for symbolic in [False, True]:
-        set_symbolic_shape(True)
         x1 = tensor(np.random.randn(3, 1, 1))
         x2 = tensor(np.random.randn(1, 4, 1))
         x3 = tensor(np.random.randn(1, 1, 5))
@@ -397,7 +398,6 @@ def test_trace_nms():


 def test_trace_valid_broadcast():
-    set_symbolic_shape(True)
     x1 = tensor(np.random.randn(1, 1))
     x2 = tensor(np.random.randn(1, 2))
     shape = (tensor([2]), tensor([2]))


imperative/python/test/unit/test_zero_dim_tensor.py (+52, -0)

@@ -0,0 +1,52 @@
+import numpy as np
+
+import megengine.functional as F
+from megengine import Tensor
+from megengine.core._trace_option import use_symbolic_shape
+
+
+def test_zero_dim():
+    a = Tensor(1)
+    a_np = np.array(1, dtype=np.int32)
+    np.testing.assert_equal(a, a_np)
+    if use_symbolic_shape():
+        np.testing.assert_equal(a.shape, np.array(a_np.shape))
+    else:
+        np.testing.assert_equal(a.shape, a_np.shape)
+
+
+def test_sum():
+    a = Tensor([1, 2])
+    a = a.reshape((1, 2))
+    assert a.sum().ndim == 0
+    assert a.sum(axis=1).ndim == 1
+
+
+def test_max():
+    a = Tensor([1, 2])
+    a = a.reshape((1, 2))
+    assert a.max().ndim == 0
+    assert a.max(axis=1).ndim == 1
+
+
+def test_reshape():
+    a = Tensor(1)
+    a = a.reshape((1, 1))
+
+
+def test_squeeze():
+    a = Tensor(1)
+    a = a.reshape((1, 1))
+    assert F.squeeze(a).ndim == 0
+
+
+def test_elemementwise():
+    a = Tensor(1.0)
+    assert F.exp(a).ndim == 0
+    assert (a + a).ndim == 0
+    assert (a + 1).ndim == 0
+
+
+def test_astype():
+    a = Tensor(1.0)
+    assert a.astype("int32").ndim == 0