You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cell.py 8.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_cell """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, Parameter
  20. from ...ut_filter import non_graph_engine
  # Leaf fixture: a cell that owns exactly one parameter.
class ModA(nn.Cell):
    """ Leaf cell wrapping ``tensor`` in a single Parameter named "weight". """

    def __init__(self, tensor):
        super().__init__()
        # Assigned after Cell.__init__ so the Parameter is registered on the cell.
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        # Intentionally a no-op: these tests only inspect parameter registration.
        return None
  # Second leaf fixture, structurally identical to ModA but a distinct type.
class ModB(nn.Cell):
    """ Leaf cell holding ``tensor`` as its only Parameter, named "weight". """

    def __init__(self, tensor):
        super().__init__()
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        # No computation needed; the fixture exists for its parameter layout.
        return None
  # Composite fixture: nests one ModA and one ModB one level deeper.
class ModC(nn.Cell):
    """ Container cell whose children are ``mod1`` (ModA) and ``mod2`` (ModB). """

    def __init__(self, ta, tb):
        super().__init__()
        # Registration order matters: the naming tests expect mod1 before mod2.
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)

    def construct(self, *inputs):
        return None
  # Main fixture for the naming / deletion tests: a two-level cell tree.
class Net(nn.Cell):
    """ Two leaf children plus a nested ModC, giving a two-level parameter tree. """

    # Parameters expected: mod1.weight, mod2.weight, mod3.mod1.weight, mod3.mod2.weight.
    name_len = 4
    # Direct child cells expected from cells(): mod1, mod2, mod3.
    cells_num = 3

    def __init__(self, ta, tb):
        super().__init__()
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        return None
  # Same tree as Net, but this top-level cell disables automatic name prefixing.
class Net2(nn.Cell):
    """ Net variant constructed with ``auto_prefix=False``. """

    def __init__(self, ta, tb):
        # test_stop_update_name expects a bare "weight" name when prefixing is off.
        super().__init__(auto_prefix=False)
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        return None
  # Small CNN used to exercise a direct construct() call.
class ConvNet(nn.Cell):
    """ conv -> batchnorm -> relu -> maxpool -> flatten -> dense classifier. """

    image_h = 224
    image_w = 224
    output_ch = 64

    def __init__(self, num_classes=10):
        super().__init__()
        channels = ConvNet.output_ch
        self.conv1 = nn.Conv2d(3, channels, kernel_size=7, stride=2, pad_mode='pad', padding=3)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.flatten = nn.Flatten()
        # conv1 (stride 2) and maxpool (stride 2) each halve H and W of the
        # 224x224 input, so the flattened feature count is H * W * C / (4 * 4).
        in_features = int(ConvNet.image_h * ConvNet.image_w * ConvNet.output_ch / (4 * 4))
        self.fc = nn.Dense(in_features, num_classes)

    def construct(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.flatten(out)
        return self.fc(out)
  # parameters_dict() keys follow the dotted attribute path of each parameter.
def test_basic():
    """ Net exposes exactly its four parameters, keyed by dotted path, in order. """
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    net = Net(first, second)
    keys = list(net.parameters_dict().keys())
    assert len(keys) == net.name_len
    assert keys == ["mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight"]
  # parameters_and_names() should report the same dotted names as parameters_dict().
def test_parameter_name():
    """ The first four named parameters of Net come out in attribute order. """
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    net = Net(first, second)
    # Entries with an empty/falsy name are skipped.
    names = [entry[0] for entry in net.parameters_and_names() if entry[0]]
    assert names[:4] == ["mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight"]
  # cells_and_names() walks the cell tree and yields dotted child paths.
def test_cell_name():
    """ Child cells are reported by dotted path; a None child does not disturb them. """
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    net = Net(first, second)
    net.insert_child_to_cell('modNone', None)
    # Falsy names (presumably the root cell's own entry) are filtered out.
    names = [entry[0] for entry in net.cells_and_names() if entry[0]]
    assert names[:5] == ["mod1", "mod2", "mod3", "mod3.mod1", "mod3.mod2"]
  # cells() yields only the direct children, not the nested ones.
def test_cells():
    """ Net has exactly ``cells_num`` immediate child cells. """
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    children = list(net.cells())
    assert len(children) == net.cells_num
  # Misuse of Cell attributes / construct must raise the documented exception types.
def test_exceptions():
    """ Invalid Cell usage is rejected with AttributeError, TypeError or NotImplementedError. """
    data = Tensor(np.ones([2, 3]))

    # Assigning a Parameter before Cell.__init__ has run -> AttributeError.
    class ModError(nn.Cell):
        """ Sets an attribute before initializing the base Cell. """

        def __init__(self, tensor):
            self.weight = Parameter(tensor, name="weight")
            super().__init__()

        def construct(self, *inputs):
            return None

    with pytest.raises(AttributeError):
        ModError(data)

    # Rebinding a parameter attribute to a Cell (even after clearing it) -> TypeError.
    class ModError1(nn.Cell):
        """ Replaces a Parameter attribute with a child cell. """

        def __init__(self, tensor):
            super().__init__()
            self.weight = Parameter(tensor, name="weight")
            self.weight = None
            self.weight = ModA(tensor)

        def construct(self, *inputs):
            return None

    with pytest.raises(TypeError):
        ModError1(data)

    # Rebinding a child-cell attribute to a plain Tensor -> TypeError.
    class ModError2(nn.Cell):
        """ Replaces a child cell attribute with a raw tensor. """

        def __init__(self, tensor):
            super().__init__()
            self.mod = ModA(tensor)
            self.mod = None
            self.mod = tensor

        def construct(self, *inputs):
            return None

    with pytest.raises(TypeError):
        ModError2(data)

    # The base Cell provides no construct implementation.
    base_cell = nn.Cell()
    with pytest.raises(NotImplementedError):
        base_cell.construct()
  # Deleting child cells or parameters should shrink parameters_dict().
def test_del():
    """ ``del`` on cells/parameters removes them from the parameter registry. """
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))

    def param_count():
        # Number of parameter names currently registered on ``net``.
        return len(list(net.parameters_dict().keys()))

    assert param_count() == net.name_len
    del net.mod1  # drops mod1.weight
    assert param_count() == net.name_len - 1
    with pytest.raises(AttributeError):
        # mod1 no longer exists, so reaching through it fails.
        del net.mod1.weight
    del net.mod2.weight
    assert param_count() == net.name_len - 2
    with pytest.raises(AttributeError):
        del net.mod  # never defined
  # insert_param_to_cell / insert_child_to_cell validate names and value types.
def test_add_attr():
    """ Explicit insertion APIs reject bad names (KeyError) and bad values (TypeError). """
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    param = Parameter(first, name="weight")
    cell = nn.Cell()

    cell.insert_param_to_cell('weight', param)
    # A Parameter is not a valid child cell.
    with pytest.raises(TypeError):
        cell.insert_child_to_cell("network", param)
    # Empty and dotted parameter names are rejected.
    with pytest.raises(KeyError):
        cell.insert_param_to_cell('', param)
    with pytest.raises(KeyError):
        cell.insert_param_to_cell('a.b', param)
    # Re-inserting under an existing name does not raise.
    cell.insert_param_to_cell('weight', param)
    # Empty and dotted child-cell names are rejected.
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('', ModA(first))
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('a.b', ModB(second))
    # Raw tensors are neither child cells nor parameters.
    with pytest.raises(TypeError):
        cell.insert_child_to_cell('buffer', second)
    with pytest.raises(TypeError):
        cell.insert_param_to_cell('w', first)
    with pytest.raises(TypeError):
        cell.insert_child_to_cell('m', param)

    # Assigning a child cell before Cell.__init__ has run -> AttributeError.
    class ModAddCellError(nn.Cell):
        """ Sets a child cell attribute before initializing the base Cell. """

        def __init__(self, tensor):
            self.mod = ModA(tensor)
            super().__init__()

        def construct(self, *inputs):
            return None

    with pytest.raises(AttributeError):
        ModAddCellError(first)
  # set_train() toggles the ``training`` flag in both directions.
def test_train_eval():
    """ A fresh Cell is not training; set_train() enables it and set_train(False) disables it. """
    cell = nn.Cell()
    assert not cell.training
    cell.set_train()
    assert cell.training
    cell.set_train(False)
    assert not cell.training
  # With auto_prefix=False on the top cell, its direct children keep bare names.
def test_stop_update_name():
    """ Net2's parameter names reflect the disabled prefixing at the top level. """
    net = Net2(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    keys = list(net.parameters_dict().keys())
    assert keys[:3] == ["weight", "mod1.weight", "mod2.weight"]
  # Direct construct() call on ConvNet, outside the graph engine.
@non_graph_engine
def test_net_call():
    """ Calling ``ConvNet.construct`` directly is expected to raise ValueError.

    Only the ``construct`` call sits inside ``pytest.raises``. A ValueError raised while
    building the network or the input tensor therefore fails the test instead of
    satisfying the expectation by accident.
    """
    net = ConvNet()
    # Random pixel values in [0, 255) with shape (1, 3, H, W), cast to float32.
    input_x = Tensor(
        np.random.randint(0, 255, [1, 3, net.image_h, net.image_w]).astype(np.float32))
    with pytest.raises(ValueError):
        net.construct(input_x)