| @@ -46,13 +46,17 @@ def test_while_forward(): | |||||
| x[idx, :, 0:2] = max_num | x[idx, :, 0:2] = max_num | ||||
| idx = idx + 1 | idx = idx + 1 | ||||
| return x | return x | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| net = MyWhileNet() | net = MyWhileNet() | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(2), dtype=ms.int32) | end = Tensor(np.array(2), dtype=ms.int32) | ||||
| x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| #pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_grad(): | def test_while_grad(): | ||||
| @@ -76,15 +80,20 @@ def test_while_grad(): | |||||
| def construct(self, *inputs): | def construct(self, *inputs): | ||||
| return grad_all(self.net)(*inputs) | return grad_all(self.net)(*inputs) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(2), dtype=ms.int32) | end = Tensor(np.array(2), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001) | |||||
| assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_with_param_forward(): | def test_while_with_param_forward(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -103,17 +112,21 @@ def test_while_with_param_forward(): | |||||
| out = out + x + self.param | out = out + x + self.param | ||||
| idx = idx + 1 | idx = idx + 1 | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| net = MyWhileNet() | net = MyWhileNet() | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(2), dtype=ms.int32) | end = Tensor(np.array(2), dtype=ms.int32) | ||||
| x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_endless_case(): | def test_while_endless_case(): | ||||
| """endless case when optmization""" | |||||
| """endless case when optimization""" | |||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -128,13 +141,17 @@ def test_while_endless_case(): | |||||
| out = out + part | out = out + part | ||||
| idx = idx + 1 | idx = idx + 1 | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| net = MyWhileNet() | net = MyWhileNet() | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(2), dtype=ms.int32) | end = Tensor(np.array(2), dtype=ms.int32) | ||||
| x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_with_param_grad(): | def test_while_with_param_grad(): | ||||
| @@ -163,15 +180,18 @@ def test_while_with_param_grad(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(2), dtype=ms.int32) | end = Tensor(np.array(2), dtype=ms.int32) | ||||
| x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_with_param_forward_with_const_branch(): | def test_while_with_param_forward_with_const_branch(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -191,14 +211,18 @@ def test_while_with_param_forward_with_const_branch(): | |||||
| out = out + idx + self.param | out = out + idx + self.param | ||||
| idx = idx + 1 | idx = idx + 1 | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = while_net | net = while_net | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(4), dtype=ms.int32) | end = Tensor(np.array(4), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_opt_endless(): | def test_while_opt_endless(): | ||||
| @@ -228,15 +252,18 @@ def test_while_opt_endless(): | |||||
| def construct(self, *inputs): | def construct(self, *inputs): | ||||
| return grad_all(self.net)(*inputs) | return grad_all(self.net)(*inputs) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(4), dtype=ms.int32) | end = Tensor(np.array(4), dtype=ms.int32) | ||||
| x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32) | x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_no_while_call(): | def test_no_while_call(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -254,14 +281,18 @@ def test_no_while_call(): | |||||
| else: | else: | ||||
| out = out + idx + self.param | out = out + idx + self.param | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = while_net | net = while_net | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(4), dtype=ms.int32) | end = Tensor(np.array(4), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_with_param_grad_with_const_branch(): | def test_while_with_param_grad_with_const_branch(): | ||||
| @@ -291,15 +322,18 @@ def test_while_with_param_grad_with_const_branch(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(4), dtype=ms.int32) | end = Tensor(np.array(4), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_for_while_with_param_grad_with_const_branch(): | def test_for_while_with_param_grad_with_const_branch(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -331,15 +365,18 @@ def test_for_while_with_param_grad_with_const_branch(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(4), dtype=ms.int32) | end = Tensor(np.array(4), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_for_while_with_param_grad_basic(): | def test_for_while_with_param_grad_basic(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -368,15 +405,18 @@ def test_for_while_with_param_grad_basic(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(4), dtype=ms.int32) | end = Tensor(np.array(4), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_for_while_with_param_grad_normal(): | def test_for_while_with_param_grad_normal(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -405,15 +445,18 @@ def test_for_while_with_param_grad_normal(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(4), dtype=ms.int32) | end = Tensor(np.array(4), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_with_param_basic_grad(): | def test_while_with_param_basic_grad(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -439,15 +482,18 @@ def test_while_with_param_basic_grad(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(3), dtype=ms.int32) | end = Tensor(np.array(3), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_with_param_basic_grad_mul(): | def test_while_with_param_basic_grad_mul(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -473,15 +519,18 @@ def test_while_with_param_basic_grad_mul(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(3), dtype=ms.int32) | end = Tensor(np.array(3), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_with_param_basic_grad_two(): | def test_while_with_param_basic_grad_two(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -508,15 +557,19 @@ def test_while_with_param_basic_grad_two(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(3), dtype=ms.int32) | end = Tensor(np.array(3), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_with_param_basic_grad_three(): | def test_while_with_param_basic_grad_three(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -544,15 +597,20 @@ def test_while_with_param_basic_grad_three(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(3), dtype=ms.int32) | end = Tensor(np.array(3), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001) | |||||
| assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_if_with_param_grad(): | def test_while_if_with_param_grad(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -581,15 +639,18 @@ def test_while_if_with_param_grad(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(3), dtype=ms.int32) | end = Tensor(np.array(3), dtype=ms.int32) | ||||
| x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) | x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_while_with_param_grad_not_enter_while(): | def test_while_with_param_grad_not_enter_while(): | ||||
| class MyWhileNet(nn.Cell): | class MyWhileNet(nn.Cell): | ||||
| @@ -614,15 +675,18 @@ def test_while_with_param_grad_not_enter_while(): | |||||
| def construct(self, a, b, c): | def construct(self, a, b, c): | ||||
| return grad_by_list(self.net, self.weights)(a, b, c) | return grad_by_list(self.net, self.weights)(a, b, c) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| while_net = MyWhileNet() | while_net = MyWhileNet() | ||||
| net = GradNet(while_net) | net = GradNet(while_net) | ||||
| idx = Tensor(np.array(3), dtype=ms.int32) | idx = Tensor(np.array(3), dtype=ms.int32) | ||||
| end = Tensor(np.array(0), dtype=ms.int32) | end = Tensor(np.array(0), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_with_param_if_by_if_forward(): | def test_with_param_if_by_if_forward(): | ||||
| class MyIfByIfNet(nn.Cell): | class MyIfByIfNet(nn.Cell): | ||||
| @@ -643,14 +707,18 @@ def test_with_param_if_by_if_forward(): | |||||
| else: | else: | ||||
| out = out + x*2 | out = out + x*2 | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(4), dtype=ms.int32) | end = Tensor(np.array(4), dtype=ms.int32) | ||||
| x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) | x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| def test_with_param_if_by_if_grad_inputs(): | def test_with_param_if_by_if_grad_inputs(): | ||||
| @@ -676,15 +744,20 @@ def test_with_param_if_by_if_grad_inputs(): | |||||
| def construct(self, *inputs): | def construct(self, *inputs): | ||||
| return grad_all(self.net)(*inputs) | return grad_all(self.net)(*inputs) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = GradNet(if_net) | net = GradNet(if_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(0), dtype=ms.int32) | end = Tensor(np.array(0), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001) | |||||
| assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001) | |||||
| def test_with_param_if_by_if_grad_parameter(): | def test_with_param_if_by_if_grad_parameter(): | ||||
| class MyIfByIfNet(nn.Cell): | class MyIfByIfNet(nn.Cell): | ||||
| @@ -710,15 +783,18 @@ def test_with_param_if_by_if_grad_parameter(): | |||||
| def construct(self, *inputs): | def construct(self, *inputs): | ||||
| return grad_by_list(self.net, self.weights)(*inputs) | return grad_by_list(self.net, self.weights)(*inputs) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = GradNet(if_net) | net = GradNet(if_net) | ||||
| idx = Tensor(np.array(0), dtype=ms.int32) | idx = Tensor(np.array(0), dtype=ms.int32) | ||||
| end = Tensor(np.array(2), dtype=ms.int32) | end = Tensor(np.array(2), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_with_param_if_by_if_grad_param_excute_null(): | def test_with_param_if_by_if_grad_param_excute_null(): | ||||
| class MyIfByIfNet(nn.Cell): | class MyIfByIfNet(nn.Cell): | ||||
| @@ -742,15 +818,18 @@ def test_with_param_if_by_if_grad_param_excute_null(): | |||||
| def construct(self, *inputs): | def construct(self, *inputs): | ||||
| return grad_by_list(self.net, self.weights)(*inputs) | return grad_by_list(self.net, self.weights)(*inputs) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = GradNet(if_net) | net = GradNet(if_net) | ||||
| idx = Tensor(np.array(4), dtype=ms.int32) | idx = Tensor(np.array(4), dtype=ms.int32) | ||||
| end = Tensor(np.array(0), dtype=ms.int32) | end = Tensor(np.array(0), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_if_by_if_return_inside_grad(): | def test_if_by_if_return_inside_grad(): | ||||
| class MyIfByIfNet(nn.Cell): | class MyIfByIfNet(nn.Cell): | ||||
| @@ -776,15 +855,18 @@ def test_if_by_if_return_inside_grad(): | |||||
| def construct(self, *inputs): | def construct(self, *inputs): | ||||
| return grad_by_list(self.net, self.weights)(*inputs) | return grad_by_list(self.net, self.weights)(*inputs) | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = GradNet(if_net) | net = GradNet(if_net) | ||||
| idx = Tensor(np.array(1), dtype=ms.int32) | idx = Tensor(np.array(1), dtype=ms.int32) | ||||
| end = Tensor(np.array(0), dtype=ms.int32) | end = Tensor(np.array(0), dtype=ms.int32) | ||||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001) | |||||
| def test_if_by_if_forward(): | def test_if_by_if_forward(): | ||||
| class MyIfByIfNet(nn.Cell): | class MyIfByIfNet(nn.Cell): | ||||
| @@ -811,18 +893,22 @@ def test_if_by_if_forward(): | |||||
| a = a * b | a = a * b | ||||
| out = a + b + x | out = a + b + x | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(2), dtype=ms.float32) | idx = Tensor(np.array(2), dtype=ms.float32) | ||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(4), dtype=ms.float32) | x = Tensor(np.array(4), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| def test_if_by_if_forward_control_tuple_switch(): | def test_if_by_if_forward_control_tuple_switch(): | ||||
| """tuple_get from swtich op will generate new switch inside to eliminate tuple_get""" | |||||
| """tuple_get from switch op will generate new switch inside to eliminate tuple_get""" | |||||
| class Branch3Net(nn.Cell): | class Branch3Net(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -871,14 +957,18 @@ def test_if_by_if_forward_control_tuple_switch(): | |||||
| a = a * b | a = a * b | ||||
| out = a + b + x | out = a + b + x | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(2), dtype=ms.float32) | idx = Tensor(np.array(2), dtype=ms.float32) | ||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(0), dtype=ms.float32) | x = Tensor(np.array(0), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| @@ -932,14 +1022,18 @@ def test_if_by_if_forward_control_inside_net(): | |||||
| a = self.sub(a, b) | a = self.sub(a, b) | ||||
| out = self.net(a, b, x) | out = self.net(a, b, x) | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(2), dtype=ms.float32) | idx = Tensor(np.array(2), dtype=ms.float32) | ||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(0), dtype=ms.float32) | x = Tensor(np.array(0), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| @@ -968,14 +1062,18 @@ def test_if_by_if_forward_use_namespace(): | |||||
| a = a * b | a = a * b | ||||
| out = a + b + x | out = a + b + x | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(2), dtype=ms.float32) | idx = Tensor(np.array(2), dtype=ms.float32) | ||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(0), dtype=ms.float32) | x = Tensor(np.array(0), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| def test_if_by_if_forward_use_global_op(): | def test_if_by_if_forward_use_global_op(): | ||||
| @@ -1007,14 +1105,18 @@ def test_if_by_if_forward_use_global_op(): | |||||
| a = a * b | a = a * b | ||||
| out = a + b + x | out = a + b + x | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(2), dtype=ms.float32) | idx = Tensor(np.array(2), dtype=ms.float32) | ||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(0), dtype=ms.float32) | x = Tensor(np.array(0), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| def test_for_with_if_by_if_forward(): | def test_for_with_if_by_if_forward(): | ||||
| @@ -1033,14 +1135,18 @@ def test_for_with_if_by_if_forward(): | |||||
| a = a * b | a = a * b | ||||
| out = a + b + x | out = a + b + x | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(2), dtype=ms.float32) | idx = Tensor(np.array(2), dtype=ms.float32) | ||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(0), dtype=ms.float32) | x = Tensor(np.array(0), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| @@ -1062,14 +1168,18 @@ def test_for_with_if_by_if_forward_namespace(): | |||||
| a = a * b | a = a * b | ||||
| out = a + b + x | out = a + b + x | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(2), dtype=ms.float32) | idx = Tensor(np.array(2), dtype=ms.float32) | ||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(0), dtype=ms.float32) | x = Tensor(np.array(0), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| @@ -1102,14 +1212,18 @@ def test_if_by_if_forward_const_branch_inner(): | |||||
| a = a * b | a = a * b | ||||
| out = a + b + x | out = a + b + x | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(2), dtype=ms.float32) | idx = Tensor(np.array(2), dtype=ms.float32) | ||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(0), dtype=ms.float32) | x = Tensor(np.array(0), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| @@ -1143,14 +1257,18 @@ def test_if_by_if_forward_all_const_branch(): | |||||
| a = a * b | a = a * b | ||||
| out = a + b + x | out = a + b + x | ||||
| return out | return out | ||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| if_net = MyIfByIfNet() | if_net = MyIfByIfNet() | ||||
| net = if_net | net = if_net | ||||
| idx = Tensor(np.array(2), dtype=ms.float32) | idx = Tensor(np.array(2), dtype=ms.float32) | ||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(0), dtype=ms.float32) | x = Tensor(np.array(0), dtype=ms.float32) | ||||
| net(idx, end, x) | |||||
| graph_output = net(idx, end, x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| pynative_output = net(idx, end, x) | |||||
| assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||