|
|
@@ -102,16 +102,41 @@ class ControlIfbyIfbyIf(nn.Cell): |
|
|
class ControlMixedWhileIf(nn.Cell): |
|
|
class ControlMixedWhileIf(nn.Cell): |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
|
|
|
self.assign = op.Assign() |
|
|
|
|
|
self.var = Parameter(initializer(1, (1), mstype.float32), name="var") |
|
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y, z, c2, c4): |
|
|
|
|
|
out = self.assign(self.var, c4) |
|
|
|
|
|
while x < c2: |
|
|
|
|
|
y = self.assign(self.var, c4) |
|
|
|
|
|
while y < c2 and x < c2: |
|
|
|
|
|
if 2 * y < c2: |
|
|
|
|
|
y = y + 2 |
|
|
|
|
|
else: |
|
|
|
|
|
y = y + 1 |
|
|
|
|
|
out = out + y |
|
|
|
|
|
z = self.assign(self.var, c4) |
|
|
|
|
|
while z < c2: |
|
|
|
|
|
z = z + 1 |
|
|
|
|
|
out = out + z |
|
|
|
|
|
x = x + 1 |
|
|
|
|
|
out = out + x |
|
|
|
|
|
while x < 2 * c2: |
|
|
|
|
|
y = self.assign(self.var, c4) |
|
|
|
|
|
x = x + 1 |
|
|
|
|
|
while y < c2: |
|
|
|
|
|
z = self.assign(self.var, c4) |
|
|
|
|
|
while z < c2: |
|
|
|
|
|
z = z + 1 |
|
|
|
|
|
if x < c2: |
|
|
|
|
|
y = y - 1 |
|
|
|
|
|
else: |
|
|
|
|
|
y = y + 1 |
|
|
|
|
|
out = out + z |
|
|
|
|
|
out = out + y |
|
|
|
|
|
out = out + x |
|
|
|
|
|
return out |
|
|
|
|
|
|
|
|
def construct(self, x, y): |
|
|
|
|
|
y = y + 4 |
|
|
|
|
|
while x < y: |
|
|
|
|
|
if 2 * x < y: |
|
|
|
|
|
x = x + 1 |
|
|
|
|
|
else: |
|
|
|
|
|
x = x + 2 |
|
|
|
|
|
x = x + 3 |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@@ -130,6 +155,7 @@ def test_simple_if(): |
|
|
expect = input2 * 3 * 3 * 2 + input1 |
|
|
expect = input2 * 3 * 3 * 2 + input1 |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
@@ -145,6 +171,7 @@ def test_simple_if_with_assign(): |
|
|
expect = input_data |
|
|
expect = input_data |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
@@ -158,6 +185,7 @@ def test_if_in_if(): |
|
|
expect = x + y + 3 |
|
|
expect = x + y + 3 |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
@@ -175,6 +203,7 @@ def test_if_by_if_by_if(): |
|
|
expect = input_data * 3 * 2 * 2 * 2 |
|
|
expect = input_data * 3 * 2 * 2 * 2 |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.level0 |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
@@ -183,7 +212,10 @@ def test_mixed_while_if(): |
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
x = np.array(2).astype(np.int32) |
|
|
x = np.array(2).astype(np.int32) |
|
|
y = np.array(14).astype(np.int32) |
|
|
y = np.array(14).astype(np.int32) |
|
|
|
|
|
z = np.array(1).astype(np.int32) |
|
|
|
|
|
c2 = Tensor([14], mstype.int32) |
|
|
|
|
|
c4 = Tensor([0], mstype.int32) |
|
|
net = ControlMixedWhileIf() |
|
|
net = ControlMixedWhileIf() |
|
|
output = net(Tensor(x), Tensor(y)) |
|
|
|
|
|
expect = np.array(22).astype(np.int32) |
|
|
|
|
|
|
|
|
output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4) |
|
|
|
|
|
expect = np.array(3318).astype(np.int32) |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |
|
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) |