You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_layer_switch.py 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import numpy as np
  2. import mindspore
  3. from mindspore import nn
  4. from mindspore import Tensor
  5. from mindspore import context
  6. from mindspore.ops import operations as P
  7. context.set_context(mode=context.GRAPH_MODE)
  8. class Layer1(nn.Cell):
  9. def __init__(self):
  10. super(Layer1, self).__init__()
  11. self.net = nn.Conv2d(3, 1, 3, pad_mode='same')
  12. self.pad = nn.Pad(
  13. paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
  14. def construct(self, x):
  15. y = self.net(x)
  16. return self.pad(y)
  17. class Layer2(nn.Cell):
  18. def __init__(self):
  19. super(Layer2, self).__init__()
  20. self.net = nn.Conv2d(3, 1, 7, pad_mode='same')
  21. self.pad = nn.Pad(
  22. paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
  23. def construct(self, x):
  24. y = self.net(x)
  25. return self.pad(y)
  26. class Layer3(nn.Cell):
  27. def __init__(self):
  28. super(Layer3, self).__init__()
  29. self.net = nn.Conv2d(3, 3, 3, pad_mode='same')
  30. def construct(self, x):
  31. return self.net(x)
  32. class SwitchNet(nn.Cell):
  33. def __init__(self):
  34. super(SwitchNet, self).__init__()
  35. self.layer1 = Layer1()
  36. self.layer2 = Layer2()
  37. self.layer3 = Layer3()
  38. self.layers = (self.layer1, self.layer2, self.layer3)
  39. self.fill = P.Fill()
  40. def construct(self, x, index):
  41. y = self.layers[index](x)
  42. return y
  43. class MySwitchNet(nn.Cell):
  44. def __init__(self):
  45. super(MySwitchNet, self).__init__()
  46. self.layer1 = Layer1()
  47. self.layer2 = Layer2()
  48. self.layer3 = Layer3()
  49. self.layers = (self.layer1, self.layer2, self.layer3)
  50. self.fill = P.Fill()
  51. def construct(self, x, index):
  52. y = self.layers[0](x)
  53. for i in range(len(self.layers)):
  54. if i == index:
  55. y = self.layers[i](x)
  56. return y
  57. def test_layer_switch():
  58. net = MySwitchNet()
  59. x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
  60. index = Tensor(0, dtype=mindspore.int32)
  61. y = net(x, index)