Browse Source

!22245 Add pynative switch layer ut test case

Merge pull request !22245 from zjun/switch_layer_ut_test
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
36238032a2
1 changed files with 22 additions and 4 deletions
  1. +22
    -4
      tests/ut/python/ops/test_layer_switch.py

+ 22
- 4
tests/ut/python/ops/test_layer_switch.py View File

@@ -21,10 +21,6 @@ from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P


context.set_context(mode=context.GRAPH_MODE)


class Layer1(nn.Cell):
def __init__(self):
super(Layer1, self).__init__()
@@ -90,7 +86,29 @@ class MySwitchNet(nn.Cell):


def test_layer_switch():
context.set_context(mode=context.GRAPH_MODE)
net = MySwitchNet()
x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
index = Tensor(0, dtype=mindspore.int32)
net(x, index)

class MySwitchNetPynative(nn.Cell):
def __init__(self):
super(MySwitchNetPynative, self).__init__()
self.layer1 = Layer1()
self.layer2 = Layer2()
self.layer3 = Layer3()
self.layers = (self.layer1, self.layer2, self.layer3)
self.fill = P.Fill()

def construct(self, x, index):
return self.layers[index](x)


def test_layer_switch_pynative():
context.set_context(mode=context.PYNATIVE_MODE)
net = MySwitchNetPynative()
x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
index = Tensor(2, dtype=mindspore.int32)
net(x, index)
context.set_context(mode=context.GRAPH_MODE)

Loading…
Cancel
Save