|
|
|
@@ -170,3 +170,52 @@ def test_for_in_if_03(): |
|
|
|
|
|
|
|
assert graph_forward_res == pynative_forward_res |
|
|
|
assert graph_backward_res == pynative_backward_res |
|
|
|
|
|
|
|
|
|
|
|
def test_for_in_if_04(): |
|
|
|
class ForInIfNet(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.param_a = Parameter(Tensor(5, mstype.int32), name='a') |
|
|
|
self.param_b = Parameter(Tensor(4, mstype.int32), name='b') |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
out = self.param_a |
|
|
|
x = self.func(x) |
|
|
|
out *= x |
|
|
|
return out |
|
|
|
|
|
|
|
def func(self, x): |
|
|
|
if self.param_a > self.param_b: |
|
|
|
for _ in range(0, 4): |
|
|
|
self.param_a += 1 |
|
|
|
self.param_b -= 3 |
|
|
|
self.param_b += 10 |
|
|
|
return x |
|
|
|
|
|
|
|
class GradNet(nn.Cell): |
|
|
|
def __init__(self, net): |
|
|
|
super(GradNet, self).__init__() |
|
|
|
self.net = net |
|
|
|
|
|
|
|
def construct(self, *inputs): |
|
|
|
return grad_all(self.net)(*inputs) |
|
|
|
|
|
|
|
x = Tensor(5, mstype.int32) |
|
|
|
|
|
|
|
# graph mode |
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
for_in_if_net = ForInIfNet() |
|
|
|
net = GradNet(for_in_if_net) |
|
|
|
graph_forward_res = for_in_if_net(x) |
|
|
|
graph_backward_res = net(x) |
|
|
|
|
|
|
|
# pynative mode |
|
|
|
context.set_context(mode=context.PYNATIVE_MODE) |
|
|
|
for_in_if_net = ForInIfNet() |
|
|
|
net = GradNet(for_in_if_net) |
|
|
|
pynative_forward_res = for_in_if_net(x) |
|
|
|
pynative_backward_res = net(x) |
|
|
|
|
|
|
|
assert graph_forward_res == pynative_forward_res |
|
|
|
assert graph_backward_res == pynative_backward_res |