|
|
|
@@ -21,10 +21,6 @@ from mindspore import Tensor |
|
|
|
from mindspore import context |
|
|
|
from mindspore.ops import operations as P |
|
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
|
|
|
|
|
|
|
|
class Layer1(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super(Layer1, self).__init__() |
|
|
|
@@ -90,7 +86,29 @@ class MySwitchNet(nn.Cell): |
|
|
|
|
|
|
|
|
|
|
|
def test_layer_switch(): |
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
net = MySwitchNet() |
|
|
|
x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32) |
|
|
|
index = Tensor(0, dtype=mindspore.int32) |
|
|
|
net(x, index) |
|
|
|
|
|
|
|
class MySwitchNetPynative(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super(MySwitchNetPynative, self).__init__() |
|
|
|
self.layer1 = Layer1() |
|
|
|
self.layer2 = Layer2() |
|
|
|
self.layer3 = Layer3() |
|
|
|
self.layers = (self.layer1, self.layer2, self.layer3) |
|
|
|
self.fill = P.Fill() |
|
|
|
|
|
|
|
def construct(self, x, index): |
|
|
|
return self.layers[index](x) |
|
|
|
|
|
|
|
|
|
|
|
def test_layer_switch_pynative(): |
|
|
|
context.set_context(mode=context.PYNATIVE_MODE) |
|
|
|
net = MySwitchNetPynative() |
|
|
|
x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32) |
|
|
|
index = Tensor(2, dtype=mindspore.int32) |
|
|
|
net(x, index) |
|
|
|
context.set_context(mode=context.GRAPH_MODE) |