
check result in test_cont_grad cases
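
Every test case in this commit is extended from "run in graph mode only" to "run in both graph and pynative mode, then compare". A minimal sketch of the recurring pattern (assuming a `net` cell and the `idx`, `end`, `x` inputs each test constructs):

    # build the output in graph mode first
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    graph_output = net(idx, end, x)
    # re-run the same cell in pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    # np.allclose takes rtol and atol positionally: both are 1e-4 here
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)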

tags/v1.2.0-rc1
liangzelang, 4 years ago
parent commit: 81afa9f103
1 changed file with 201 additions and 83 deletions

tests/st/control/test_cont_grad.py  (+201, -83)
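For the gradient tests (a GradNet wrapping the net with grad_all or grad_by_list), the output is a tuple of gradients, so the new assertions compare element by element, e.g. graph_output[0].asnumpy() against pynative_output[0].asnumpy(), with the same 1e-4 tolerances.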

@@ -46,13 +46,17 @@ def test_while_forward():
                 x[idx, :, 0:2] = max_num
                 idx = idx + 1
             return x
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     net = MyWhileNet()
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(2), dtype=ms.int32)
     x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    #pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 def test_while_grad():
@@ -76,15 +80,20 @@ def test_while_grad():

         def construct(self, *inputs):
             return grad_all(self.net)(*inputs)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(2), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)


 def test_while_with_param_forward():
     class MyWhileNet(nn.Cell):
@@ -103,17 +112,21 @@ def test_while_with_param_forward():
                 out = out + x + self.param
                 idx = idx + 1
             return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     net = MyWhileNet()
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(2), dtype=ms.int32)
     x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 def test_while_endless_case():
-    """endless case when optmization"""
+    """endless case when optimization"""
     class MyWhileNet(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -128,13 +141,17 @@ def test_while_endless_case():
                 out = out + part
                 idx = idx + 1
             return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     net = MyWhileNet()
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(2), dtype=ms.int32)
     x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 def test_while_with_param_grad():
@@ -163,15 +180,18 @@ def test_while_with_param_grad():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(2), dtype=ms.int32)
     x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_while_with_param_forward_with_const_branch():
     class MyWhileNet(nn.Cell):
@@ -191,14 +211,18 @@ def test_while_with_param_forward_with_const_branch():
                 out = out + idx + self.param
                 idx = idx + 1
             return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = while_net
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(4), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 def test_while_opt_endless():
@@ -228,15 +252,18 @@ def test_while_opt_endless():

         def construct(self, *inputs):
             return grad_all(self.net)(*inputs)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(4), dtype=ms.int32)
     x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_no_while_call():
     class MyWhileNet(nn.Cell):
@@ -254,14 +281,18 @@ def test_no_while_call():
            else:
                out = out + idx + self.param
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = while_net
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(4), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 def test_while_with_param_grad_with_const_branch():
@@ -291,15 +322,18 @@ def test_while_with_param_grad_with_const_branch():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(4), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_for_while_with_param_grad_with_const_branch():
     class MyWhileNet(nn.Cell):
@@ -331,15 +365,18 @@ def test_for_while_with_param_grad_with_const_branch():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(4), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_for_while_with_param_grad_basic():
     class MyWhileNet(nn.Cell):
@@ -368,15 +405,18 @@ def test_for_while_with_param_grad_basic():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(4), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_for_while_with_param_grad_normal():
     class MyWhileNet(nn.Cell):
@@ -405,15 +445,18 @@ def test_for_while_with_param_grad_normal():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(4), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_while_with_param_basic_grad():
     class MyWhileNet(nn.Cell):
@@ -439,15 +482,18 @@ def test_while_with_param_basic_grad():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(3), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_while_with_param_basic_grad_mul():
     class MyWhileNet(nn.Cell):
@@ -473,15 +519,18 @@ def test_while_with_param_basic_grad_mul():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(3), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_while_with_param_basic_grad_two():
     class MyWhileNet(nn.Cell):
@@ -508,15 +557,19 @@ def test_while_with_param_basic_grad_two():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(3), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)


 def test_while_with_param_basic_grad_three():
     class MyWhileNet(nn.Cell):
@@ -544,15 +597,20 @@ def test_while_with_param_basic_grad_three():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(3), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)


 def test_while_if_with_param_grad():
     class MyWhileNet(nn.Cell):
@@ -581,15 +639,18 @@ def test_while_if_with_param_grad():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(3), dtype=ms.int32)
     x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_while_with_param_grad_not_enter_while():
     class MyWhileNet(nn.Cell):
@@ -614,15 +675,18 @@ def test_while_with_param_grad_not_enter_while():

         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
     idx = Tensor(np.array(3), dtype=ms.int32)
     end = Tensor(np.array(0), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_with_param_if_by_if_forward():
     class MyIfByIfNet(nn.Cell):
@@ -643,14 +707,18 @@ def test_with_param_if_by_if_forward():
            else:
                out = out + x*2
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(4), dtype=ms.int32)
     x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 def test_with_param_if_by_if_grad_inputs():
@@ -676,15 +744,20 @@ def test_with_param_if_by_if_grad_inputs():

         def construct(self, *inputs):
             return grad_all(self.net)(*inputs)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = GradNet(if_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(0), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)


 def test_with_param_if_by_if_grad_parameter():
     class MyIfByIfNet(nn.Cell):
@@ -710,15 +783,18 @@ def test_with_param_if_by_if_grad_parameter():

         def construct(self, *inputs):
             return grad_by_list(self.net, self.weights)(*inputs)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = GradNet(if_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(2), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_with_param_if_by_if_grad_param_excute_null():
     class MyIfByIfNet(nn.Cell):
@@ -742,15 +818,18 @@ def test_with_param_if_by_if_grad_param_excute_null():

         def construct(self, *inputs):
             return grad_by_list(self.net, self.weights)(*inputs)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = GradNet(if_net)
     idx = Tensor(np.array(4), dtype=ms.int32)
     end = Tensor(np.array(0), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_if_by_if_return_inside_grad():
     class MyIfByIfNet(nn.Cell):
@@ -776,15 +855,18 @@ def test_if_by_if_return_inside_grad():

         def construct(self, *inputs):
             return grad_by_list(self.net, self.weights)(*inputs)
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = GradNet(if_net)
     idx = Tensor(np.array(1), dtype=ms.int32)
     end = Tensor(np.array(0), dtype=ms.int32)
     x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


 def test_if_by_if_forward():
     class MyIfByIfNet(nn.Cell):
@@ -811,18 +893,22 @@ def test_if_by_if_forward():
            a = a * b
            out = a + b + x
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(2), dtype=ms.float32)
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(4), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 def test_if_by_if_forward_control_tuple_switch():
-    """tuple_get from swtich op will generate new switch inside to eliminate tuple_get"""
+    """tuple_get from switch op will generate new switch inside to eliminate tuple_get"""
     class Branch3Net(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -871,14 +957,18 @@ def test_if_by_if_forward_control_tuple_switch():
            a = a * b
            out = a + b + x
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(2), dtype=ms.float32)
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

@@ -932,14 +1022,18 @@ def test_if_by_if_forward_control_inside_net():
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(2), dtype=ms.float32)
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

@@ -968,14 +1062,18 @@ def test_if_by_if_forward_use_namespace():
            a = a * b
            out = a + b + x
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(2), dtype=ms.float32)
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 def test_if_by_if_forward_use_global_op():
@@ -1007,14 +1105,18 @@ def test_if_by_if_forward_use_global_op():
            a = a * b
            out = a + b + x
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(2), dtype=ms.float32)
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 def test_for_with_if_by_if_forward():
@@ -1033,14 +1135,18 @@ def test_for_with_if_by_if_forward():
            a = a * b
            out = a + b + x
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(2), dtype=ms.float32)
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

@@ -1062,14 +1168,18 @@ def test_for_with_if_by_if_forward_namespace():
            a = a * b
            out = a + b + x
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(2), dtype=ms.float32)
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

@@ -1102,14 +1212,18 @@ def test_if_by_if_forward_const_branch_inner():
            a = a * b
            out = a + b + x
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(2), dtype=ms.float32)
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

@@ -1143,14 +1257,18 @@ def test_if_by_if_forward_all_const_branch():
            a = a * b
            out = a + b + x
            return out
+    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     if_net = MyIfByIfNet()
     net = if_net
     idx = Tensor(np.array(2), dtype=ms.float32)
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
-    net(idx, end, x)
+    graph_output = net(idx, end, x)
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    pynative_output = net(idx, end, x)
+    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


 @pytest.mark.level0

