|
|
|
@@ -42,7 +42,7 @@ def test_while_forward(): |
|
|
|
idx = idx + 1 |
|
|
|
return x |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
net = MyWhileNet() |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
end = Tensor(np.array(2), dtype=ms.int32) |
|
|
|
@@ -72,7 +72,7 @@ def test_while_grad(): |
|
|
|
def construct(self, *inputs): |
|
|
|
return C.grad_all(self.net)(*inputs) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -99,7 +99,7 @@ def test_while_with_param_forward(): |
|
|
|
idx = idx + 1 |
|
|
|
return out |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
net = MyWhileNet() |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
end = Tensor(np.array(2), dtype=ms.int32) |
|
|
|
@@ -124,7 +124,7 @@ def test_while_endless_case(): |
|
|
|
idx = idx + 1 |
|
|
|
return out |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
net = MyWhileNet() |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
end = Tensor(np.array(2), dtype=ms.int32) |
|
|
|
@@ -159,7 +159,7 @@ def test_while_with_param_grad(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -187,7 +187,7 @@ def test_while_with_param_forward_with_const_branch(): |
|
|
|
idx = idx + 1 |
|
|
|
return out |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = while_net |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -224,7 +224,7 @@ def test_while_opt_endless(): |
|
|
|
def construct(self, *inputs): |
|
|
|
return C.grad_all(self.net)(*inputs) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -250,7 +250,7 @@ def test_no_while_call(): |
|
|
|
out = out + idx + self.param |
|
|
|
return out |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = while_net |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -287,7 +287,7 @@ def test_while_with_param_grad_with_const_branch(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -327,7 +327,7 @@ def test_for_while_with_param_grad_with_const_branch(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -364,7 +364,7 @@ def test_for_while_with_param_grad_basic(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -401,7 +401,7 @@ def test_for_while_with_param_grad_normal(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -435,7 +435,7 @@ def test_while_with_param_basic_grad(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -469,7 +469,7 @@ def test_while_with_param_basic_grad_mul(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -504,7 +504,7 @@ def test_while_with_param_basic_grad_two(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -540,7 +540,7 @@ def test_while_with_param_basic_grad_three(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -577,7 +577,7 @@ def test_while_if_with_param_grad(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -610,7 +610,7 @@ def test_while_with_param_grad_not_enter_while(): |
|
|
|
def construct(self, a, b, c): |
|
|
|
return C.grad_by_list(self.net, self.weights)(a, b, c) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
while_net = MyWhileNet() |
|
|
|
net = GradNet(while_net) |
|
|
|
idx = Tensor(np.array(3), dtype=ms.int32) |
|
|
|
@@ -639,7 +639,7 @@ def test_with_param_if_by_if_forward(): |
|
|
|
out = out + x*2 |
|
|
|
return out |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
if_net = MyIfByIfNet() |
|
|
|
net = if_net |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -672,7 +672,7 @@ def test_with_param_if_by_if_grad_inputs(): |
|
|
|
def construct(self, *inputs): |
|
|
|
return C.grad_all(self.net)(*inputs) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
if_net = MyIfByIfNet() |
|
|
|
net = GradNet(if_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -706,7 +706,7 @@ def test_with_param_if_by_if_grad_parameter(): |
|
|
|
def construct(self, *inputs): |
|
|
|
return C.grad_by_list(self.net, self.weights)(*inputs) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
if_net = MyIfByIfNet() |
|
|
|
net = GradNet(if_net) |
|
|
|
idx = Tensor(np.array(0), dtype=ms.int32) |
|
|
|
@@ -738,7 +738,7 @@ def test_with_param_if_by_if_grad_param_excute_null(): |
|
|
|
def construct(self, *inputs): |
|
|
|
return C.grad_by_list(self.net, self.weights)(*inputs) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
if_net = MyIfByIfNet() |
|
|
|
net = GradNet(if_net) |
|
|
|
idx = Tensor(np.array(4), dtype=ms.int32) |
|
|
|
@@ -772,7 +772,7 @@ def test_if_by_if_return_inside_grad(): |
|
|
|
def construct(self, *inputs): |
|
|
|
return C.grad_by_list(self.net, self.weights)(*inputs) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
if_net = MyIfByIfNet() |
|
|
|
net = GradNet(if_net) |
|
|
|
idx = Tensor(np.array(1), dtype=ms.int32) |
|
|
|
@@ -807,10 +807,342 @@ def test_if_by_if_forward(): |
|
|
|
out = a + b + x |
|
|
|
return out |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
if_net = MyIfByIfNet() |
|
|
|
net = if_net |
|
|
|
idx = Tensor(np.array(2), dtype=ms.float32) |
|
|
|
end = Tensor(np.array(3), dtype=ms.float32) |
|
|
|
x = Tensor(np.array(4), dtype=ms.float32) |
|
|
|
net(idx, end, x) |
|
|
|
|
|
|
|
|
|
|
|
def test_if_by_if_forward_control_tuple_switch():
    """tuple_get from switch op will generate new switch inside to eliminate tuple_get"""

    class Branch3Net(nn.Cell):
        # Innermost cell: one tensor-dependent branch, returns a *tuple* (a, b, x)
        # so the caller's unpack exercises tuple_getitem on a switch output.
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()   # sub/mul/div are created but unused in this cell
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            return a, b, x

    class Branch2Net(nn.Cell):
        # Middle cell: its own branch, then delegates to Branch3Net,
        # forwarding the tuple result unchanged.
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        # Outer cell: branch, call the nested net, unpack its tuple, then
        # combine the pieces so all three tuple elements are consumed.
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            # tuple unpack of a value produced under nested switches
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # Forward pass only; no output assertion — the test passes if graph
    # compilation and execution succeed.
    net(idx, end, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_if_by_if_forward_control_inside_net():
    """Forward if-by-if where the final computation lives in the innermost cell.

    Unlike the tuple_switch variant, Branch3Net returns a single tensor, so
    the nested calls chain scalar results instead of unpacking tuples.
    """

    class Branch3Net(nn.Cell):
        # Innermost cell: branch, then compute the final scalar output.
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()   # sub/mul/div are created but unused in this cell
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        # Middle cell: branch, then delegate to Branch3Net.
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        # Outer cell: branch, then return the nested net's result directly.
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # Forward pass only; no output assertion — the test passes if graph
    # compilation and execution succeed.
    net(idx, end, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_if_by_if_forward_use_namespace():
    """Forward if-by-if where ops are instantiated inline via the P namespace.

    Each branch constructs its primitive at call time (e.g. ``P.TensorAdd()(a, b)``)
    instead of using the pre-built ``self.*`` attributes, exercising namespace
    resolution inside graph-mode branches.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Pre-built ops are defined but construct() deliberately ignores
            # them in favor of inline P.<Op>() instantiation.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = P.TensorAdd()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.TensorAdd()(a, b)
            else:
                b = P.TensorAdd()(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # Forward pass only; no output assertion — the test passes if graph
    # compilation and execution succeed.
    net(idx, end, x)
|
|
|
|
|
|
|
|
|
|
|
def test_if_by_if_forward_use_global_op():
    """Forward if-by-if where ops are built as locals inside construct.

    The primitives are instantiated once at the top of ``construct`` and
    reused across branches, rather than taken from ``self.*`` attributes.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Pre-built ops are defined but construct() deliberately ignores
            # them in favor of locally created primitives.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.TensorAdd()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # Forward pass only; no output assertion — the test passes if graph
    # compilation and execution succeed.
    net(idx, end, x)
|
|
|
|
|
|
|
|
|
|
|
def test_for_with_if_by_if_forward():
    """Forward pass with a tensor-dependent if inside a fixed-trip-count for loop."""

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()

        def construct(self, a, b, x):
            # Constant-bound loop (4 iterations); the branch condition a < b
            # depends on tensor values and may flip between iterations.
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # Forward pass only; no output assertion — the test passes if graph
    # compilation and execution succeed.
    net(idx, end, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_for_with_if_by_if_forward_namespace():
    """Forward for+if test with ops instantiated inline from the P namespace.

    Same structure as test_for_with_if_by_if_forward but with 6 iterations and
    primitives created per call (``P.TensorAdd()(a, b)``) inside the loop body.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Pre-built ops are defined but construct() deliberately ignores
            # them in favor of inline P.<Op>() instantiation.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            # Constant-bound loop (6 iterations) with a tensor-dependent branch.
            for _ in range(0, 6):
                if a < b:
                    a = P.TensorAdd()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # Forward pass only; no output assertion — the test passes if graph
    # compilation and execution succeed.
    net(idx, end, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_if_by_if_forward_const_branch_inner():
    """Forward if-by-if where the middle condition is a compile-time constant.

    The ``if 2 > 1`` branch is statically true, so the compiler can fold it
    while the surrounding tensor-dependent branches still need real switches.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Pre-built ops are defined but construct() deliberately ignores
            # them in favor of locally created primitives.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.TensorAdd()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            # Constant condition: always-true branch, foldable at compile time.
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # Forward pass only; no output assertion — the test passes if graph
    # compilation and execution succeed.
    net(idx, end, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_if_by_if_forward_all_const_branch():
    """Forward if-by-if where every condition is a compile-time constant.

    All three branches (``2 < 12``, ``2 > 1``, ``2 == 1``) are statically
    decidable, so no runtime switch should survive compilation.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Pre-built ops are defined but construct() deliberately ignores
            # them in favor of locally created primitives.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.TensorAdd()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            # Always true: add branch taken.
            if 2 < 12:
                a = add(a, b)
            else:
                a = sub(a, b)
            # Always true: mul branch taken.
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            # Always false: add(a, x) branch taken.
            if 2 == 1:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # Forward pass only; no output assertion — the test passes if graph
    # compilation and execution succeed.
    net(idx, end, x)