|
|
|
@@ -29,11 +29,13 @@ context.set_context(mode=context.GRAPH_MODE) |
|
|
|
def add_func(x, y): |
|
|
|
return x + y |
|
|
|
|
|
|
|
|
|
|
|
@ms_function |
|
|
|
def do_increment(i): |
|
|
|
add_1 = F.partial(add_func, 1) |
|
|
|
return add_1(i) |
|
|
|
|
|
|
|
|
|
|
|
def test_increment(): |
|
|
|
a = do_increment(9) |
|
|
|
assert a == 10 |
|
|
|
@@ -45,6 +47,7 @@ def use_monad(x, y): |
|
|
|
res = F.depend(res, monad.U) |
|
|
|
return res |
|
|
|
|
|
|
|
|
|
|
|
def test_use_monad(): |
|
|
|
x = Tensor(1.0, mstype.float32) |
|
|
|
y = Tensor(1.0, mstype.float32) |
|
|
|
@@ -62,6 +65,7 @@ class Net(nn.Cell): |
|
|
|
print(i) |
|
|
|
return x_len |
|
|
|
|
|
|
|
|
|
|
|
def test_builtins_len(): |
|
|
|
net = Net() |
|
|
|
net() |
|
|
|
@@ -75,6 +79,7 @@ def np_fallback_func(): |
|
|
|
me_x = me_x + me_x |
|
|
|
return me_x |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason='Not support graph fallback feature yet') |
|
|
|
def test_np_fallback_func(): |
|
|
|
print(np_fallback_func()) |
|
|
|
@@ -88,6 +93,7 @@ def div_mod_func1(): |
|
|
|
a = divmod(x, y) |
|
|
|
return Tensor(a) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason='Not support graph fallback feature yet') |
|
|
|
def test_div_mod_func1(): |
|
|
|
print(div_mod_func1()) # (2, 2) |
|
|
|
@@ -99,6 +105,7 @@ def div_mod_func2(x, y): |
|
|
|
a = divmod(x, y) |
|
|
|
return Tensor(a) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason='Not support graph fallback feature yet') |
|
|
|
def test_div_mod_func2_scalar(): |
|
|
|
""" |
|
|
|
@@ -108,6 +115,7 @@ def test_div_mod_func2_scalar(): |
|
|
|
""" |
|
|
|
print(div_mod_func2(8, 3)) # (2, 2) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason='Not support graph fallback feature yet') |
|
|
|
def test_div_mod_func2_tensor(): |
|
|
|
""" |
|
|
|
@@ -129,6 +137,7 @@ def select_func(cond, x, y): |
|
|
|
output = x |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
def test_select_func(): |
|
|
|
cond = Tensor([True, False]) |
|
|
|
x = Tensor([2, 3], mstype.float32) |
|
|
|
@@ -147,6 +156,7 @@ def select_func2(cond, x, y): |
|
|
|
output = x |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
def test_select_func2(): |
|
|
|
cond = Tensor([True, False]) |
|
|
|
x = Tensor([2, 3], mstype.float32) |
|
|
|
@@ -160,7 +170,62 @@ def slice_func(a, b): |
|
|
|
a[1:3, ::] = b |
|
|
|
return a |
|
|
|
|
|
|
|
|
|
|
|
def test_slice_func(): |
|
|
|
a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32) |
|
|
|
b = Tensor([1], dtype=mstype.float32) |
|
|
|
print(slice_func(a, b)) |
|
|
|
|
|
|
|
|
|
|
|
@ms_function |
|
|
|
def np_fallback_func_tensor_index(x): |
|
|
|
array_x = tuple([2, 3, 4, 5]) |
|
|
|
np_x = np.array(array_x).astype(np.float32) |
|
|
|
me_x = Tensor(np_x) |
|
|
|
me_x = me_x + me_x |
|
|
|
return me_x[x] |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason='Not support graph fallback feature yet') |
|
|
|
def test_np_fallback_func_tensor_index(): |
|
|
|
""" |
|
|
|
Feature: Fallback feature: support Tensor index. |
|
|
|
Description: Fallback feature: support Tensor index. |
|
|
|
Expectation: Fallback feature: support Tensor index. |
|
|
|
""" |
|
|
|
x = Tensor(1, mstype.int32) |
|
|
|
output = np_fallback_func_tensor_index(x) |
|
|
|
output_expect = Tensor(6, mstype.float32) |
|
|
|
assert output == output_expect |
|
|
|
|
|
|
|
|
|
|
|
class ControlNet(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super(ControlNet, self).__init__() |
|
|
|
|
|
|
|
def inner_function_1(self, a, b): |
|
|
|
return a + b |
|
|
|
|
|
|
|
def inner_function_2(self, a, b): |
|
|
|
return a - b |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
a = Tensor(np.array(4), mstype.int32) |
|
|
|
b = Tensor(np.array(5), mstype.int32) |
|
|
|
if a + b > x: |
|
|
|
return self.inner_function_1(a, b) |
|
|
|
return self.inner_function_2(a, b) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason='Not support graph fallback feature yet') |
|
|
|
def test_fallback_control_sink_tensor(): |
|
|
|
""" |
|
|
|
Feature: Fallback feature: support define Tensor in Class construct. |
|
|
|
Description: Fallback feature: support define Tensor in Class construct. |
|
|
|
Expectation: Fallback feature: support define Tensor in Class construct. |
|
|
|
""" |
|
|
|
x = Tensor(np.array(1), mstype.int32) |
|
|
|
net = ControlNet() |
|
|
|
output = net(x) |
|
|
|
output_expect = Tensor(9, mstype.int32) |
|
|
|
assert output == output_expect |