| @@ -26,6 +26,9 @@ from mindspore.ops import operations as P | |||
| # from tests.vm_impl.vm_interface import * | |||
| # from tests.vm_impl import * | |||
| grad_by_list = C.GradOperation('get_by_list', get_by_list=True) | |||
| grad_all = C.GradOperation('get_all', get_all=True) | |||
| def setup_module(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) | |||
| @@ -86,7 +89,7 @@ def test_while_opt_endless(): | |||
| @ms_function | |||
| def construct(self, *inputs): | |||
| return C.grad_all(self.net)(*inputs) | |||
| return grad_all(self.net)(*inputs) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -149,7 +152,7 @@ def test_while_with_param_grad_with_const_branch(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -189,7 +192,7 @@ def test_for_while_with_param_grad_with_const_branch(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -226,7 +229,7 @@ def test_for_while_with_param_grad_basic(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -263,7 +266,7 @@ def test_for_while_with_param_grad_normal(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -297,7 +300,7 @@ def test_while_with_param_basic_grad(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -331,7 +334,7 @@ def test_while_with_param_basic_grad_mul(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -366,7 +369,7 @@ def test_while_with_param_basic_grad_two(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -402,7 +405,7 @@ def test_while_with_param_basic_grad_three(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -439,7 +442,7 @@ def test_while_if_with_param_grad(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -472,7 +475,7 @@ def test_while_with_param_grad_not_enter_while(): | |||
| @ms_function | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| return grad_by_list(self.net, self.weights)(a, b, c) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| @@ -534,7 +537,7 @@ def test_with_param_if_by_if_grad_inputs(): | |||
| @ms_function | |||
| def construct(self, *inputs): | |||
| return C.grad_all(self.net)(*inputs) | |||
| return grad_all(self.net)(*inputs) | |||
| if_net = MyIfByIfNet() | |||
| net = GradNet(if_net) | |||
| @@ -568,7 +571,7 @@ def test_with_param_if_by_if_grad_parameter(): | |||
| @ms_function | |||
| def construct(self, *inputs): | |||
| return C.grad_by_list(self.net, self.weights)(*inputs) | |||
| return grad_by_list(self.net, self.weights)(*inputs) | |||
| if_net = MyIfByIfNet() | |||
| net = GradNet(if_net) | |||
| @@ -600,7 +603,7 @@ def test_with_param_if_by_if_grad_param_excute_null(): | |||
| @ms_function | |||
| def construct(self, *inputs): | |||
| return C.grad_by_list(self.net, self.weights)(*inputs) | |||
| return grad_by_list(self.net, self.weights)(*inputs) | |||
| if_net = MyIfByIfNet() | |||
| net = GradNet(if_net) | |||
| @@ -634,7 +637,7 @@ def test_if_by_if_return_inside_grad(): | |||
| @ms_function | |||
| def construct(self, *inputs): | |||
| return C.grad_by_list(self.net, self.weights)(*inputs) | |||
| return grad_by_list(self.net, self.weights)(*inputs) | |||
| if_net = MyIfByIfNet() | |||
| net = GradNet(if_net) | |||