@@ -100,6 +100,8 @@ class GradManager:
         :param ys: outputs of forward operators, e.g., the loss tensor
         :param dys: derivatives of ys
         """
+        from ..functional import ones_like
+
         global backwarding_grad_manager
         cache = backwarding_grad_manager
         backwarding_grad_manager = self
@@ -113,7 +115,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
+            dys = [ones_like(y) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
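
The hunk above swaps the default gradient seed. A minimal sketch (not part of the patch, assuming a MegEngine build that already contains this change) of why the two expressions agree:

    # ones_like(y) fills a tensor of y's shape and dtype with ones, which is
    # exactly what tensor(1.0).broadcast(y.shape) used to produce as the default dy.
    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    y = tensor(np.random.rand(2, 3).astype("float32"))
    seed_new = F.ones_like(y)                          # new default seed
    seed_old = F.broadcast_to(tensor(1.0), y.shape)    # what the removed line computed
    np.testing.assert_array_equal(seed_new.numpy(), seed_old.numpy())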
@@ -160,7 +160,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
@@ -186,7 +186,7 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
@@ -267,7 +267,7 @@ def setitem(tensor, index, value):
                         value.shape, tmp_result.shape
                     )
                 )
-        value = value.broadcast(tmp_result.shape)
+        value = value._broadcast(tmp_result.shape)
     if use_subtensor:
         op = builtin.SetSubtensor(items=items)
     else:
@@ -396,7 +396,8 @@ class ArrayMethodMixin(abc.ABC):
     def reshape(self, *args):
         return _reshape(self, _expand_args(args))

-    def broadcast(self, *args):
+    # FIXME: remove this method
+    def _broadcast(self, *args):
         return _broadcast(self, _expand_args(args))

     def transpose(self, *args):
@@ -23,7 +23,16 @@ from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, max, sum
-from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
+from .tensor import (
+    broadcast_to,
+    concat,
+    expand_dims,
+    full,
+    ones,
+    reshape,
+    squeeze,
+    zeros,
+)
 from .types import _pair, _pair_nonzero

 __all__ = [
@@ -635,7 +644,7 @@ def batch_norm2d(
     def full_value(value):
         C = inp.shape[1]
         (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
-        return broadcast(x, [1, C, 1, 1])
+        return broadcast_to(x, [1, C, 1, 1])

     def expand_or_full(x, value):
         if x is None:
@@ -754,7 +763,7 @@ def sync_batch_norm(
     if is_distributed():
         # reduce all nodes' data to calculate mean and variance
-        reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
+        reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
         stat = concat(
             [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
         )
@@ -968,10 +977,10 @@ def matmul(
         if dim1 != dim2:
             if dim1 < dim2:
                 shape1 = shape2[: dim2 - dim1] + shape1
-                inp1 = inp1.broadcast(*shape1)
+                inp1 = broadcast_to(inp1, shape1)
             else:
                 shape2 = shape1[: dim1 - dim2] + shape2
-                inp2 = inp2.broadcast(*shape2)
+                inp2 = broadcast_to(inp2, shape2)
         reshaped_batch_size = 1
         for i in shape1[:-2]:
             reshaped_batch_size *= i
@@ -986,9 +995,9 @@ def matmul(
         shp = shape1[:-1] + shape2[-1:]
     elif dim1 == 3 or dim2 == 3:
         if dim2 < 3:
-            inp2 = inp2.broadcast(*(inp1.shape[:1] + inp2.shape))
+            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
         elif dim1 < 3:
-            inp1 = inp1.broadcast(*(inp2.shape[:1] + inp1.shape))
+            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
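
A shape-only sketch of the batch-dimension alignment performed in the two matmul hunks above (hypothetical shapes, not taken from the patch; assumes a build with broadcast_to): the lower-rank operand has its shape left-padded with the other operand's leading batch dims before the batched multiply.

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    inp1 = tensor(np.random.rand(4, 5).astype("float32"))     # rank 2
    inp2 = tensor(np.random.rand(8, 5, 6).astype("float32"))  # rank 3

    # mirror of "inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)"
    inp1_b = F.broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
    print(inp1_b.shape)  # (8, 4, 5), ready for the batched matrix multiply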
@@ -1205,7 +1214,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
     else:
         hscale = 1.0 * ih / oh
         wscale = 1.0 * iw / ow
@@ -1221,7 +1230,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))

     weight = weight.astype("float32")
     ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR")
@@ -19,7 +19,7 @@ from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
-from ..core.tensor.tensor_wrapper import _remove_axis
+from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
 from ..core.tensor.utils import (
     astensor1d,
     convert_inputs,
@@ -33,7 +33,7 @@ from .elemwise import ceil
 __all__ = [
     "arange",
-    "broadcast",
+    "broadcast_to",
     "concat",
     "cond_take",
     "expand_dims",
@@ -104,7 +104,7 @@ def full(shape, value, dtype="float32", device=None):
     (x,) = Const(value, dtype=dtype, device=device)(
         Tensor(value, dtype=dtype, device=device)
     )
-    return broadcast(x, shape)
+    return broadcast_to(x, shape)


 def ones(shape, dtype="float32", device=None):
@@ -192,7 +192,7 @@ def identity(inp: Tensor) -> Tensor:
     return output


-def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
+def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
     """
     Broadcasts a tensor to given shape.
@@ -209,7 +209,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
         import megengine.functional as F

         data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
-        out = F.broadcast(data, (4, 2, 3))
+        out = F.broadcast_to(data, (4, 2, 3))
         print(out.numpy())

     Outputs:
@@ -229,7 +229,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
          [3. 4. 5.]]]

     """
-    return inp.broadcast(shape)
+    return _broadcast(inp, shape)


 def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
@@ -395,8 +395,7 @@ def _get_idx(index, axis):
                 0, index.shape[i] - 1, index.shape[i], device=index.device,
             )
             arange = (
-                arange.reshape(*shape)
-                .broadcast(index.shape)
+                broadcast_to(arange.reshape(*shape), index.shape)
                 .reshape(-1)
                 .astype(np.int32)
             )
@@ -15,7 +15,7 @@ from ..core.ops.builtin import Copy
 from ..core.tensor import Tensor
 from ..core.tensor.core import apply
 from .math import topk as _topk
-from .tensor import transpose as _transpose
+from .tensor import broadcast_to, transpose


 def accuracy(
@@ -54,8 +54,8 @@ def accuracy(
     _, pred = _topk(logits, k=max(topk), descending=True)
     accs = []
     for k in topk:
-        correct = pred[:, :k].detach() == _transpose(target, (0, "x")).broadcast(
-            target.shape[0], k
+        correct = pred[:, :k].detach() == broadcast_to(
+            transpose(target, (0, "x")), (target.shape[0], k)
         )
         accs.append(correct.astype(np.float32).sum() / target.shape[0])
     if len(topk) == 1:  # type: ignore[arg-type]
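
To make the accuracy hunk above concrete, a toy walk-through (hypothetical values, not part of the patch): the target vector is turned into a column via transpose(target, (0, "x")), broadcast to (N, k), and compared element-wise against the top-k predictions.

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    pred = tensor(np.array([[3, 1], [0, 2]], dtype="int32"))   # top-2 predictions, shape (2, 2)
    target = tensor(np.array([1, 2], dtype="int32"))           # ground-truth labels, shape (2,)

    target_col = F.transpose(target, (0, "x"))                 # shape (2, 1)
    correct = pred == F.broadcast_to(target_col, (target.shape[0], 2))
    print(correct.numpy())  # [[False  True]
                            #  [False  True]]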
@@ -319,7 +319,7 @@ def test_Broadcast():
     x = TensorWrapper(x_np)
     grad = Grad().wrt(x, callback=save_to(x))

-    y = F.broadcast(x, (3, 3, 10))
+    y = F.broadcast_to(x, (3, 3, 10))
     grad(y, F.ones_like(y))
     np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())
@@ -251,17 +251,17 @@ def test_broadcast():
         {"input": [data1, output1_shape], "output": output1_shape},
         {"input": [data2, output2_shape], "output": output2_shape},
     ]
-    opr_test(cases, F.broadcast, compare_fn=compare_fn)
+    opr_test(cases, F.broadcast_to, compare_fn=compare_fn)

     x = F.ones((2, 1, 3))
     with pytest.raises(ValueError):
-        F.broadcast(x, (2, 3, 4))
+        F.broadcast_to(x, (2, 3, 4))

     with pytest.raises(ValueError):
-        F.broadcast(x, (4, 1, 3))
+        F.broadcast_to(x, (4, 1, 3))

     with pytest.raises(ValueError):
-        F.broadcast(x, (1, 3))
+        F.broadcast_to(x, (1, 3))


 def test_utils_astensor1d():
@@ -351,7 +351,7 @@ def test_trace_broadcast():
     @trace(symbolic=symbolic, capture_as_const=True)
     def f(x):
-        y = x.broadcast((3, 4, 5))
+        y = F.broadcast_to(x, (3, 4, 5))
         return y

     f(x1)
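
Taken together, the hunks above rename the public API from F.broadcast to F.broadcast_to and demote the tensor method to the private helper _broadcast (marked FIXME for removal). A call-site migration sketch, assuming a MegEngine build that already contains this change:

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(6, dtype="float32").reshape(2, 3))

    # before: F.broadcast(x, (4, 2, 3)) or x.broadcast((4, 2, 3))
    # after:  the only public entry point is the functional form
    y = F.broadcast_to(x, (4, 2, 3))
    print(y.shape)  # (4, 2, 3)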