| @@ -57,6 +57,28 @@ def _transpose(data, axes): | |||||
| def _broadcast(inp, shape): | def _broadcast(inp, shape): | ||||
| def valid_broadcast(src, tar): | |||||
| def failed(): | |||||
| raise ValueError( | |||||
| "the input shape {} can not be broadcasted to target shape {}".format( | |||||
| src, tar | |||||
| ) | |||||
| ) | |||||
| if isinstance(src, (Tensor, TensorWrapperBase)): | |||||
| src = src.numpy() | |||||
| if isinstance(tar, (Tensor, TensorWrapperBase)): | |||||
| tar = tar.numpy() | |||||
| if len(src) > len(tar): | |||||
| failed() | |||||
| for i in range(min(len(src), len(tar))): | |||||
| if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: | |||||
| failed() | |||||
| valid_broadcast(inp.shape, shape) | |||||
| shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | ||||
| (result,) = apply(builtin.Broadcast(), inp, shape) | (result,) = apply(builtin.Broadcast(), inp, shape) | ||||
| return result | return result | ||||
| @@ -240,7 +240,7 @@ def test_broadcast(): | |||||
| output1_shape = (30, 20, 30) | output1_shape = (30, 20, 30) | ||||
| data1 = np.random.random(input1_shape).astype(np.float32) | data1 = np.random.random(input1_shape).astype(np.float32) | ||||
| input2_shape = (10, 20) | |||||
| input2_shape = (10, 1) | |||||
| output2_shape = (20, 10, 20) | output2_shape = (20, 10, 20) | ||||
| data2 = np.random.random(input2_shape).astype(np.float32) | data2 = np.random.random(input2_shape).astype(np.float32) | ||||
| @@ -253,6 +253,16 @@ def test_broadcast(): | |||||
| ] | ] | ||||
| opr_test(cases, F.broadcast, compare_fn=compare_fn) | opr_test(cases, F.broadcast, compare_fn=compare_fn) | ||||
| x = F.ones((2, 1, 3)) | |||||
| with pytest.raises(ValueError): | |||||
| F.broadcast(x, (2, 3, 4)) | |||||
| with pytest.raises(ValueError): | |||||
| F.broadcast(x, (4, 1, 3)) | |||||
| with pytest.raises(ValueError): | |||||
| F.broadcast(x, (1, 3)) | |||||
| def test_utils_astensor1d(): | def test_utils_astensor1d(): | ||||
| reference = tensor(0) | reference = tensor(0) | ||||
| @@ -340,3 +340,20 @@ def test_raise_on_trace(): | |||||
| step_count += 1 | step_count += 1 | ||||
| assert catch_count == 1 | assert catch_count == 1 | ||||
| def test_trace_broadcast(): | |||||
| for symbolic in [False, True]: | |||||
| set_tensor_shape(True) | |||||
| x1 = tensor(np.random.randn(3, 1, 1)) | |||||
| x2 = tensor(np.random.randn(1, 4, 1)) | |||||
| x3 = tensor(np.random.randn(1, 1, 5)) | |||||
| @trace(symbolic=symbolic, capture_as_const=True) | |||||
| def f(x): | |||||
| y = x.broadcast((3, 4, 5)) | |||||
| return y | |||||
| f(x1) | |||||
| f(x2) | |||||
| f(x3) | |||||