|
|
|
@@ -1429,6 +1429,33 @@ def test_if_cast(): |
|
|
|
np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy()) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
|
@pytest.mark.env_onecard |
|
|
|
def test_while_forward(): |
|
|
|
class MyWhileNet(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.max = P.ReduceMax() |
|
|
|
|
|
|
|
def construct(self, idx, end, x): |
|
|
|
while idx < end: |
|
|
|
part = x[idx, :, :] |
|
|
|
max_num = self.max(part) |
|
|
|
x[idx, :, 0:2] = max_num |
|
|
|
idx = idx + 1 |
|
|
|
return x |
|
|
|
|
|
|
|
net = MyWhileNet() |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
end = Tensor(np.array(2), dtype=ms.int32) |
|
|
|
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) |
|
|
|
output = net(idx, end, x) |
|
|
|
expect = np.array([[[3, 3], [3, 3]], [[7, 7], [7, 7]]], dtype=np.int32) |
|
|
|
assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason="not supported yet") |
|
|
|
def test_multi_add_assign(): |
|
|
|
class Net(Cell): |
|
|
|
|