| @@ -1352,10 +1352,11 @@ def roll( | |||||
| if shift_ == 0: | if shift_ == 0: | ||||
| continue | continue | ||||
| size = shp[axis_normalized_] | size = shp[axis_normalized_] | ||||
| if shift_ > 0: | |||||
| a, b = split(out, [size - shift_,], axis=axis_normalized_) | |||||
| shift_normalized_ = 0 if size == 0 else shift_ % size | |||||
| if shift_normalized_ > 0: | |||||
| a, b = split(out, [size - shift_normalized_,], axis=axis_normalized_) | |||||
| else: | else: | ||||
| a, b = split(out, [-shift_,], axis=axis_normalized_) | |||||
| a, b = split(out, [-shift_normalized_,], axis=axis_normalized_) | |||||
| out = concat((b, a), axis=axis_normalized_) | out = concat((b, a), axis=axis_normalized_) | ||||
| if shp_bak is not None: | if shp_bak is not None: | ||||
| out = out.reshape(shp_bak) | out = out.reshape(shp_bak) | ||||
| @@ -806,6 +806,8 @@ def test_tile(shape, reps, is_varnode): | |||||
| [ | [ | ||||
| ((2, 3), 0, None), | ((2, 3), 0, None), | ||||
| ((2, 3), 1, 0), | ((2, 3), 1, 0), | ||||
| ((2, 3), 100, 0), | |||||
| ((2, 3), -100, 0), | |||||
| ((2, 3, 4, 5), (-1, 1), (0, 1)), | ((2, 3, 4, 5), (-1, 1), (0, 1)), | ||||
| ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)), | ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)), | ||||
| ], | ], | ||||
| @@ -829,3 +831,24 @@ def test_roll(shape, shifts, axis, is_varnode): | |||||
| opr_test( | opr_test( | ||||
| cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network | cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network | ||||
| ) | ) | ||||
| @pytest.mark.parametrize( | |||||
| "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),], | |||||
| ) | |||||
| @pytest.mark.parametrize("is_symbolic", [None, True, False]) | |||||
| def test_roll_empty_tensor(shape, shifts, axis, is_symbolic): | |||||
| inp = tensor(np.random.randn(*shape).astype("float32")) | |||||
| def func(inp): | |||||
| return F.roll(inp, shifts, axis) | |||||
| if is_symbolic is not None: | |||||
| func = trace(symbolic=is_symbolic)(func) | |||||
| out_ref = np.roll(inp.numpy(), shifts, axis) | |||||
| for _ in range(3): | |||||
| out = F.roll(inp, shifts, axis) | |||||
| np.testing.assert_equal(out.numpy(), out_ref) | |||||
| if is_symbolic is None: | |||||
| break | |||||
| @@ -1339,8 +1339,10 @@ void Concat::scn_do_execute() { | |||||
| if (real_axis < 0) | if (real_axis < 0) | ||||
| real_axis += in.shape().ndim; | real_axis += in.shape().ndim; | ||||
| end = begin + in.shape().shape[real_axis]; | end = begin + in.shape().shape[real_axis]; | ||||
| out.sub(Slice(begin, end).apply(out.layout(), real_axis)). | |||||
| copy_from_fixlayout(in); | |||||
| if (!in.layout().is_empty()) { | |||||
| out.sub(Slice(begin, end).apply(out.layout(), real_axis)). | |||||
| copy_from_fixlayout(in); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||