Browse Source

!25511 Add some fallback testcases

Merge pull request !25511 from Margaret_wangrui/fallback_tests
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
a7834c382f
1 changed files with 65 additions and 0 deletions
  1. +65
    -0
      tests/ut/python/fallback/test_graph_fallback.py

+ 65
- 0
tests/ut/python/fallback/test_graph_fallback.py View File

@@ -29,11 +29,13 @@ context.set_context(mode=context.GRAPH_MODE)
def add_func(x, y):
return x + y


@ms_function
def do_increment(i):
add_1 = F.partial(add_func, 1)
return add_1(i)


def test_increment():
a = do_increment(9)
assert a == 10
@@ -45,6 +47,7 @@ def use_monad(x, y):
res = F.depend(res, monad.U)
return res


def test_use_monad():
x = Tensor(1.0, mstype.float32)
y = Tensor(1.0, mstype.float32)
@@ -62,6 +65,7 @@ class Net(nn.Cell):
print(i)
return x_len


def test_builtins_len():
net = Net()
net()
@@ -75,6 +79,7 @@ def np_fallback_func():
me_x = me_x + me_x
return me_x


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_fallback_func():
print(np_fallback_func())
@@ -88,6 +93,7 @@ def div_mod_func1():
a = divmod(x, y)
return Tensor(a)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_div_mod_func1():
print(div_mod_func1()) # (2, 2)
@@ -99,6 +105,7 @@ def div_mod_func2(x, y):
a = divmod(x, y)
return Tensor(a)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_div_mod_func2_scalar():
"""
@@ -108,6 +115,7 @@ def test_div_mod_func2_scalar():
"""
print(div_mod_func2(8, 3)) # (2, 2)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_div_mod_func2_tensor():
"""
@@ -129,6 +137,7 @@ def select_func(cond, x, y):
output = x
return output


def test_select_func():
cond = Tensor([True, False])
x = Tensor([2, 3], mstype.float32)
@@ -147,6 +156,7 @@ def select_func2(cond, x, y):
output = x
return output


def test_select_func2():
cond = Tensor([True, False])
x = Tensor([2, 3], mstype.float32)
@@ -160,7 +170,62 @@ def slice_func(a, b):
a[1:3, ::] = b
return a


def test_slice_func():
a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
b = Tensor([1], dtype=mstype.float32)
print(slice_func(a, b))


@ms_function
def np_fallback_func_tensor_index(x):
array_x = tuple([2, 3, 4, 5])
np_x = np.array(array_x).astype(np.float32)
me_x = Tensor(np_x)
me_x = me_x + me_x
return me_x[x]


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_fallback_func_tensor_index():
"""
Feature: Fallback feature: support Tensor index.
Description: Fallback feature: support Tensor index.
Expectation: Fallback feature: support Tensor index.
"""
x = Tensor(1, mstype.int32)
output = np_fallback_func_tensor_index(x)
output_expect = Tensor(6, mstype.float32)
assert output == output_expect


class ControlNet(nn.Cell):
def __init__(self):
super(ControlNet, self).__init__()

def inner_function_1(self, a, b):
return a + b

def inner_function_2(self, a, b):
return a - b

def construct(self, x):
a = Tensor(np.array(4), mstype.int32)
b = Tensor(np.array(5), mstype.int32)
if a + b > x:
return self.inner_function_1(a, b)
return self.inner_function_2(a, b)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_control_sink_tensor():
"""
Feature: Fallback feature: support define Tensor in Class construct.
Description: Fallback feature: support define Tensor in Class construct.
Expectation: Fallback feature: support define Tensor in Class construct.
"""
x = Tensor(np.array(1), mstype.int32)
net = ControlNet()
output = net(x)
output_expect = Tensor(9, mstype.int32)
assert output == output_expect

Loading…
Cancel
Save