|
|
|
@@ -45,3 +45,36 @@ def test_hyper_param(): |
|
|
|
output = net(x, y) |
|
|
|
output_expect = Tensor(39, ms.float32) |
|
|
|
assert output == output_expect |
|
|
|
|
|
|
|
|
|
|
|
def test_hyper_param_with_control_sink(): |
|
|
|
""" |
|
|
|
Feature: Resolve parameter. |
|
|
|
Description: Parameters whose name are the same between different graphs do not affect each other. |
|
|
|
Expectation: self.a is different from a in construct. |
|
|
|
""" |
|
|
|
class HyperParamNet(Cell): |
|
|
|
def __init__(self): |
|
|
|
super(HyperParamNet, self).__init__() |
|
|
|
self.a = Parameter(Tensor(1, ms.float32), name="a") |
|
|
|
self.b = Parameter(Tensor(5, ms.float32), name="b") |
|
|
|
self.c = Parameter(Tensor(9, ms.float32), name="c") |
|
|
|
|
|
|
|
def func_inner(self, c): |
|
|
|
return self.a + self.b + c |
|
|
|
|
|
|
|
def func_inner_2(self, a, c): |
|
|
|
return a - self.b + c |
|
|
|
|
|
|
|
def construct(self, a, b): |
|
|
|
self.b = b |
|
|
|
if a > self.b: |
|
|
|
return self.func_inner_2(a, self.c) |
|
|
|
return self.func_inner(self.c) |
|
|
|
|
|
|
|
x = Tensor(11, ms.float32) |
|
|
|
y = Tensor(19, ms.float32) |
|
|
|
net = HyperParamNet() |
|
|
|
output = net(x, y) |
|
|
|
output_expect = Tensor(29, ms.float32) |
|
|
|
assert output == output_expect |