| @@ -279,8 +279,8 @@ class GradManager: | |||||
| tensor.grad = grad | tensor.grad = grad | ||||
| else: | else: | ||||
| tensor.grad += grad | tensor.grad += grad | ||||
| if tensor.isscalar() and tensor.grad is not None: | |||||
| tensor.grad.setscalar() | |||||
| if tensor._isscalar() and tensor.grad is not None: | |||||
| tensor.grad._setscalar() | |||||
| finally: | finally: | ||||
| self.release() | self.release() | ||||
| backwarding_grad_manager = cache | backwarding_grad_manager = cache | ||||
| @@ -225,7 +225,7 @@ def getitem(tensor, index): | |||||
| op = builtin.IndexingMultiAxisVec(items=items) | op = builtin.IndexingMultiAxisVec(items=items) | ||||
| (result,) = apply(op, tensor, *tensors) | (result,) = apply(op, tensor, *tensors) | ||||
| if ret_scalar: | if ret_scalar: | ||||
| result.setscalar() | |||||
| result._setscalar() | |||||
| return result | return result | ||||
| @@ -51,10 +51,10 @@ def concatenate(inputs, axis=0, *, device=None): | |||||
def astype(x, dtype):
    """
    Convert ``x`` to ``dtype``, re-applying its scalar flag after conversion.

    :param x: tensor to convert.
    :param dtype: target data type; normalized through ``np.dtype``.
    :return: ``x`` unchanged when its dtype already equals the target,
        otherwise the output of the ``TypeCvt`` op.
    """
    target = np.dtype(dtype)
    if is_dtype_equal(x.dtype, target):
        return x
    # Record the scalar flag first so it can be set again on the converted tensor.
    was_scalar = x._isscalar()
    (converted,) = apply(builtin.TypeCvt(dtype=target), x)
    if was_scalar:
        converted._setscalar()
    return converted
| @@ -98,14 +98,14 @@ def result_type(*args): | |||||
def isscalar(x):
    """
    Tell whether ``x`` is a scalar.

    :param x: a ``Tensor`` or any other object.
    :return: ``x._isscalar()`` for a ``Tensor``, ``np.isscalar(x)`` otherwise.
    """
    return x._isscalar() if isinstance(x, Tensor) else np.isscalar(x)
def setscalar(x):
    """
    Set the scalar flag on the tensor ``x``.

    :param x: tensor on which ``_setscalar()`` is called.
    :raises NotImplementedError: if ``x`` is not a ``Tensor``.
    """
    if not isinstance(x, Tensor):
        raise NotImplementedError("Unsupport type {}".format(type(x)))
    x._setscalar()
| @@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): | |||||
| outputs = apply(op, inp) | outputs = apply(op, inp) | ||||
| for s, x in zip(shapes, outputs): | for s, x in zip(shapes, outputs): | ||||
| if not s: | if not s: | ||||
| x.setscalar() | |||||
| x._setscalar() | |||||
| return outputs | return outputs | ||||
| @@ -12,8 +12,8 @@ from ..core.ops.builtin import InplaceAdd | |||||
def _inplace_add_(dest, delta, alpha, beta):
    """
    Apply ``InplaceAdd`` to the operands and reset ``dest`` to the result.

    :param dest: tensor that receives the first output of the op.
    :param delta: operand forwarded to ``InplaceAdd``.
    :param alpha: operand forwarded to ``InplaceAdd``.
    :param beta: operand forwarded to ``InplaceAdd``.
    :return: ``dest``, with its scalar flag set again if it was set on entry.
    """
    # Read the flag before ``_reset`` so it can be re-applied afterwards.
    was_scalar = dest._isscalar()
    outputs = apply(InplaceAdd(), dest, delta, alpha, beta)
    dest._reset(outputs[0])
    if was_scalar:
        dest._setscalar()
    return dest
| @@ -44,11 +44,13 @@ __all__ = [ | |||||
| "linspace", | "linspace", | ||||
| "ones", | "ones", | ||||
| "ones_like", | "ones_like", | ||||
| "repeat", | |||||
| "reshape", | "reshape", | ||||
| "split", | "split", | ||||
| "squeeze", | "squeeze", | ||||
| "stack", | "stack", | ||||
| "scatter", | "scatter", | ||||
| "tile", | |||||
| "transpose", | "transpose", | ||||
| "where", | "where", | ||||
| "zeros", | "zeros", | ||||
| @@ -987,3 +989,144 @@ def arange( | |||||
| if np.dtype(dtype) == np.int32: | if np.dtype(dtype) == np.int32: | ||||
| return result.astype(dtype) | return result.astype(dtype) | ||||
| return result | return result | ||||
def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
    """
    Repeat every element of ``inp`` ``repeats`` times along ``axis``.

    :param inp: input tensor.
    :param repeats: how many times each element is repeated; must be at least 1.
    :param axis: the axis to repeat along; must lie in ``[0, ndim - 1]``.
        When omitted, the input is flattened first and a flat tensor is returned.
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine import tensor

        x = tensor([[1, 2], [3, 4]], np.int32)
        y = F.repeat(x, 2, axis=0)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [[1 2]
         [1 2]
         [3 4]
         [3 4]]

    """
    if axis is None:
        # No axis given: work on the flattened input along its only axis.
        inp = inp.reshape(-1)
        axis = 0
    # Drop the scalar flag (if set) before the shape is read below.
    if inp._isscalar():
        inp._unsetscalar()
    dims = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    # assume inp.ndim is not changed during trace
    last_axis = len(dims) - 1
    assert 0 <= axis <= last_axis
    assert repeats >= 1

    # Three shape recipes, each a list of pieces joined by ``concat``:
    #   expanded: input shape with a length-1 axis inserted right after ``axis``
    #   tiled:    same, with the inserted axis broadcast to ``repeats``
    #   final:    inserted axis merged back into ``axis``
    expanded, tiled, final = [], [], []
    if axis != 0:
        final.append(dims[:axis])
    expanded += [dims[: axis + 1], [1]]
    tiled += [dims[: axis + 1], [repeats]]
    final += [dims[axis] * repeats]
    if axis + 1 <= last_axis:
        expanded.append(dims[axis + 1 :])
        tiled.append(dims[axis + 1 :])
        final.append(dims[axis + 1 :])

    with_new_axis = inp.reshape(concat(expanded))
    broadcasted = broadcast_to(with_new_axis, concat(tiled))
    return broadcasted.reshape(concat(final))
def _tile_one_dim(inp, rep, axis):
    """
    Tile ``inp`` ``rep`` times along a single ``axis``.

    :param inp: input tensor.
    :param rep: number of copies along ``axis``.
    :param axis: index of the axis to tile.
    :return: tensor whose size along ``axis`` is ``inp.shape[axis] * rep``.
    """
    dims = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    # assume inp.ndim is not changed during trace
    last_axis = len(dims) - 1

    # Shape recipes (lists of pieces joined by ``concat``):
    #   expanded: input shape with a length-1 axis inserted right before ``axis``
    #   tiled:    same, with the inserted axis broadcast to ``rep``
    #   final:    inserted axis merged into ``axis``
    expanded, tiled, final = [], [], []
    if axis != 0:
        expanded.append(dims[:axis])
        tiled.append(dims[:axis])
        final.append(dims[:axis])
    expanded += [[1], dims[axis:]]
    tiled += [rep, dims[axis:]]
    final.append(dims[axis] * rep)
    if axis + 1 <= last_axis:
        final.append(dims[axis + 1 :])

    with_new_axis = inp.reshape(concat(expanded))
    broadcasted = broadcast_to(with_new_axis, concat(tiled))
    return broadcasted.reshape(concat(final))
def tile(inp: Tensor, reps: Iterable[int]):
    """
    Build a tensor by repeating ``inp`` as many times as ``reps`` specifies.

    With ``d = len(reps)``, the result has ``max(d, inp.ndim)`` dimensions.
    ``d >= inp.ndim`` is required. When ``inp.ndim < d``, ``inp`` is promoted
    to ``d`` dimensions by prepending new axes.

    :param inp: input tensor.
    :param reps: the number of repetitions of ``inp`` along each axis.
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine import tensor

        x = tensor([[1, 2], [3, 4]], np.int32)
        y = F.tile(x, (2,1))
        print(y.numpy())

    Outputs:

    .. testoutput::

        [[1 2]
         [3 4]
         [1 2]
         [3 4]]

    """
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
    ndim_inp = len(shape)
    ndim_reps = len(reps)
    assert (
        ndim_reps >= ndim_inp
    ), "Number of dimensions of tiled dims can not be smaller than number of dimensions of tensor"

    # The trailing ``ndim_inp`` entries of ``reps`` line up with the input axes.
    offset = ndim_reps - ndim_inp
    for axis in range(ndim_inp):
        inp = _tile_one_dim(inp, reps[axis + offset], axis)

    if ndim_reps > ndim_inp:
        # Leading entries of ``reps`` become new axes in front of the tiled input.
        # NOTE(review): ``reps[:-ndim_inp]`` is an empty slice when ndim_inp == 0;
        # confirm zero-dimensional inputs are not expected here.
        tiled_shape = inp.shape
        leading = reps[:-ndim_inp]
        leading_ones = ones_like(leading)
        with_unit_axes = concat([leading_ones, tiled_shape])
        broadcast_shape = concat([leading, tiled_shape])
        result_shape = concat([leading, tiled_shape])
        inp = broadcast_to(inp.reshape(with_unit_axes), broadcast_shape).reshape(
            result_shape
        )

    return inp
| @@ -51,10 +51,6 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
| cn = device._cn | cn = device._cn | ||||
| if isinstance(data, _Tensor): | if isinstance(data, _Tensor): | ||||
| if dtype is not None: | |||||
| logger.warning( | |||||
| "dtype does not work when creating a new Tensor with another Tensor" | |||||
| ) | |||||
| obj = _Tensor.__new__(cls, data) | obj = _Tensor.__new__(cls, data) | ||||
| else: | else: | ||||
| if isinstance(data, np.ndarray): | if isinstance(data, np.ndarray): | ||||
| @@ -557,6 +557,11 @@ void TensorWrapper::setscalar() { | |||||
| } | } | ||||
| void TensorWrapper::unsetscalar() { | |||||
| m_tensor->m_flags &= ~Tensor::Flags::SCALAR; | |||||
| } | |||||
| struct TensorWeakRef { | struct TensorWeakRef { | ||||
| std::weak_ptr<Tensor> wptr; | std::weak_ptr<Tensor> wptr; | ||||
| @@ -794,8 +799,9 @@ void init_tensor(py::module m) { | |||||
| .def_getset<&TensorWrapper::dtype>("dtype") | .def_getset<&TensorWrapper::dtype>("dtype") | ||||
| .def_getset<&TensorWrapper::device>("device") | .def_getset<&TensorWrapper::device>("device") | ||||
| .def<&TensorWrapper::reset>("_reset") | .def<&TensorWrapper::reset>("_reset") | ||||
| .def<&TensorWrapper::isscalar>("isscalar") | |||||
| .def<&TensorWrapper::setscalar>("setscalar") | |||||
| .def<&TensorWrapper::isscalar>("_isscalar") | |||||
| .def<&TensorWrapper::setscalar>("_setscalar") | |||||
| .def<&TensorWrapper::unsetscalar>("_unsetscalar") | |||||
| .def<&TensorWrapper::detach>("detach") | .def<&TensorWrapper::detach>("detach") | ||||
| .def<&TensorWrapper::_dev_tensor>("_dev_tensor") | .def<&TensorWrapper::_dev_tensor>("_dev_tensor") | ||||
| .def<&TensorWrapper::_swap_out>("_swap_out") | .def<&TensorWrapper::_swap_out>("_swap_out") | ||||
| @@ -153,6 +153,7 @@ struct TensorWrapper { | |||||
| PyObject* detach(); | PyObject* detach(); | ||||
| PyObject* isscalar(); | PyObject* isscalar(); | ||||
| void setscalar(); | void setscalar(); | ||||
| void unsetscalar(); | |||||
| PyObject* _dev_tensor(); | PyObject* _dev_tensor(); | ||||
| void _swap_in(); | void _swap_in(); | ||||
| void _swap_out(); | void _swap_out(); | ||||
| @@ -406,3 +406,53 @@ def test_copy_d2h(): | |||||
| def test_copy_d2d(): | def test_copy_d2d(): | ||||
| copy_test("gpu0", "gpu1") | copy_test("gpu0", "gpu1") | ||||
| copy_test("gpu0:0", "gpu0:1") | copy_test("gpu0:0", "gpu0:1") | ||||
@pytest.mark.parametrize(
    "shape, repeats, axis",
    [
        ((2,), 2, 0),
        ((2, 3, 4, 5), 3, 0),
        ((2, 3, 4, 5), 4, 3),
        ((2,), 2, None),
        ((2, 3, 4, 5), 3, None),
        ((), 1, None),
        ((), 10, None),
    ],
)
def test_repeat(shape, repeats, axis):
    """Compare ``F.repeat`` with ``np.repeat`` for each parametrized case."""
    # An empty shape means a zero-dimensional input built from a fixed value.
    if shape == ():
        data = np.array(1.23)
    else:
        data = np.random.randn(*shape).astype("float32")

    def run_repeat(inp):
        return F.repeat(inp=inp, repeats=repeats, axis=axis)

    def reference(inp):
        return np.repeat(inp, repeats, axis)

    opr_test([{"input": data}], run_repeat, ref_fn=reference)
@pytest.mark.parametrize(
    "shape, reps",
    [
        ((2,), (2,)),
        ((2, 3, 4, 5), (1, 1, 1, 1)),
        ((2, 3, 4, 5), (1, 2, 3, 4)),
        ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
    ],
)
def test_tile(shape, reps):
    """Compare ``F.tile`` with ``np.tile`` for each parametrized case."""
    data = np.random.randn(*shape).astype("float32")

    def run_tile(inp):
        return F.tile(inp=inp, reps=reps)

    def reference(inp):
        return np.tile(inp, reps)

    opr_test([{"input": data}], run_tile, ref_fn=reference)
| @@ -7,6 +7,7 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import io | import io | ||||
| import itertools | |||||
| from tempfile import mkstemp | from tempfile import mkstemp | ||||
| import numpy as np | import numpy as np | ||||
| @@ -359,7 +360,7 @@ def test_trace_warp_perspective(): | |||||
| np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2])) | np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2])) | ||||
| return out | return out | ||||
| for i in range(1): | |||||
| for i in range(3): | |||||
| f(x, M) | f(x, M) | ||||