You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_cell.py 8.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_cell """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, Parameter
  20. from ...ut_filter import non_graph_engine
  class ModA(nn.Cell):
      """Cell holding a single Parameter registered under the name "weight"."""

      def __init__(self, tensor):
          """Wrap ``tensor`` in a Parameter and store it as ``self.weight``."""
          super().__init__()
          self.weight = Parameter(tensor, name="weight")

      def construct(self, *inputs):
          """Accept any positional inputs and perform no computation."""
          return None
  class ModB(nn.Cell):
      """Second single-parameter cell; its Parameter is also named "weight"."""

      def __init__(self, tensor):
          """Wrap ``tensor`` in a Parameter and store it as ``self.weight``."""
          super().__init__()
          self.weight = Parameter(tensor, name="weight")

      def construct(self, *inputs):
          """Accept any positional inputs and perform no computation."""
          return None
  class ModC(nn.Cell):
      """Composite cell with two children: a ModA (``mod1``) and a ModB (``mod2``)."""

      def __init__(self, ta, tb):
          """Build ``mod1`` from ``ta`` and ``mod2`` from ``tb``, in that order."""
          super().__init__()
          self.mod1 = ModA(ta)
          self.mod2 = ModB(tb)

      def construct(self, *inputs):
          """Accept any positional inputs and perform no computation."""
          return None
  class Net(nn.Cell):
      """Three-child network used by the naming, counting and deletion tests."""

      # Number of parameter names the tests expect from parameters_dict().
      name_len = 4
      # Number of direct child cells the tests expect from cells().
      cells_num = 3

      def __init__(self, ta, tb):
          """Create ``mod1`` (ModA), ``mod2`` (ModB) and ``mod3`` (ModC) in order."""
          super().__init__()
          self.mod1 = ModA(ta)
          self.mod2 = ModB(tb)
          self.mod3 = ModC(ta, tb)

      def construct(self, *inputs):
          """Accept any positional inputs and perform no computation."""
          return None
  class Net2(nn.Cell):
      """Same child layout as Net, but the base cell is created with auto_prefix=False."""

      def __init__(self, ta, tb):
          """Create ``mod1`` (ModA), ``mod2`` (ModB) and ``mod3`` (ModC) in order."""
          super().__init__(auto_prefix=False)
          self.mod1 = ModA(ta)
          self.mod2 = ModB(tb)
          self.mod3 = ModC(ta, tb)

      def construct(self, *inputs):
          """Accept any positional inputs and perform no computation."""
          return None
  class ConvNet(nn.Cell):
      """Small classifier: conv -> batch norm -> ReLU -> max pool -> flatten -> dense."""

      # Input image height and width used to size the dense layer.
      image_h = 224
      image_w = 224
      # Number of output channels of the convolution.
      output_ch = 64

      def __init__(self, num_classes=10):
          """Create the layers; ``num_classes`` is the dense layer's output size."""
          super().__init__()
          channels = ConvNet.output_ch
          self.conv1 = nn.Conv2d(3, channels, kernel_size=7, stride=2, pad_mode='pad', padding=3)
          self.bn1 = nn.BatchNorm2d(channels)
          self.relu = nn.ReLU()
          self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
          self.flatten = nn.Flatten()
          # NOTE(review): the 4 * 4 divisor presumably reflects the two stride-2
          # stages (conv and max pool) halving each spatial dimension -- confirm.
          flat_features = int(
              ConvNet.image_h * ConvNet.image_w * ConvNet.output_ch / (4 * 4))
          self.fc = nn.Dense(flat_features, num_classes)

      def construct(self, x):
          """Run ``x`` through every layer in definition order and return the result."""
          activated = self.relu(self.bn1(self.conv1(x)))
          pooled = self.maxpool(activated)
          return self.fc(self.flatten(pooled))
  def test_basic():
      """Check the count and dotted names of the parameters reported by Net."""
      net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
      names = list(net.parameters_dict().keys())
      assert len(names) == net.name_len
      expected = [
          "mod1.weight",
          "mod2.weight",
          "mod3.mod1.weight",
          "mod3.mod2.weight",
      ]
      # Compare position by position so a failure pinpoints the offending index.
      for idx, want in enumerate(expected):
          assert names[idx] == want
  def test_parameter_name():
      """Check the names yielded by parameters_and_names(), skipping empty ones."""
      net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
      # Each yielded item carries the name at index 0; falsy names are dropped.
      names = [item[0] for item in net.parameters_and_names() if item[0]]
      expected = [
          "mod1.weight",
          "mod2.weight",
          "mod3.mod1.weight",
          "mod3.mod2.weight",
      ]
      for idx, want in enumerate(expected):
          assert names[idx] == want
  def test_cell_name():
      """Check the names yielded by cells_and_names() after inserting a None child."""
      net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
      net.insert_child_to_cell('modNone', None)
      # Each yielded item carries the name at index 0; falsy names are dropped.
      names = [item[0] for item in net.cells_and_names() if item[0]]
      expected = ["mod1", "mod2", "mod3", "mod3.mod1", "mod3.mod2"]
      for idx, want in enumerate(expected):
          assert names[idx] == want
  def test_cells():
      """Check that cells() yields as many children as Net.cells_num declares."""
      net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
      children = list(net.cells())
      assert len(children) == net.cells_num
  def test_exceptions():
      """Check the errors expected from invalid attribute assignments on a Cell."""
      tensor_in = Tensor(np.ones([2, 3]))

      class ModError(nn.Cell):
          """Assigns a Parameter before the base initializer has run."""

          def __init__(self, tensor):
              self.weight = Parameter(tensor, name="weight")
              super().__init__()

          def construct(self, *inputs):
              return None

      with pytest.raises(AttributeError):
          ModError(tensor_in)

      class ModError1(nn.Cell):
          """Rebinds ``weight`` from a Parameter to None and then to a Cell."""

          def __init__(self, tensor):
              super().__init__()
              self.weight = Parameter(tensor, name="weight")
              self.weight = None
              self.weight = ModA(tensor)

          def construct(self, *inputs):
              return None

      with pytest.raises(TypeError):
          ModError1(tensor_in)

      class ModError2(nn.Cell):
          """Rebinds ``mod`` from a Cell to None and then to a plain Tensor."""

          def __init__(self, tensor):
              super().__init__()
              self.mod = ModA(tensor)
              self.mod = None
              self.mod = tensor

          def construct(self, *inputs):
              return None

      with pytest.raises(TypeError):
          ModError2(tensor_in)

      # A bare Cell's construct() is expected to return None.
      bare_cell = nn.Cell()
      assert bare_cell.construct() is None
  def test_del():
      """Check that deleting children or parameters shrinks the parameter dict."""
      net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
      assert len(list(net.parameters_dict().keys())) == net.name_len

      del net.mod1
      assert len(list(net.parameters_dict().keys())) == net.name_len - 1

      # mod1 was removed above, so deleting through it is expected to fail.
      with pytest.raises(AttributeError):
          del net.mod1.weight

      del net.mod2.weight
      assert len(list(net.parameters_dict().keys())) == net.name_len - 2

      # Net never assigns an attribute called "mod".
      with pytest.raises(AttributeError):
          del net.mod
  def test_add_attr():
      """Check insert_param_to_cell / insert_child_to_cell argument validation."""
      ta = Tensor(np.ones([2, 3]))
      tb = Tensor(np.ones([1, 4]))
      param = Parameter(ta, name="weight")
      cell = nn.Cell()
      cell.insert_param_to_cell('weight', param)

      # A Parameter is rejected where a child cell is expected.
      with pytest.raises(TypeError):
          cell.insert_child_to_cell("network", param)

      # Empty and dotted parameter names are expected to be rejected.
      with pytest.raises(KeyError):
          cell.insert_param_to_cell('', param)
      with pytest.raises(KeyError):
          cell.insert_param_to_cell('a.b', param)

      # Re-inserting under the existing name is not wrapped in raises().
      cell.insert_param_to_cell('weight', param)

      # Empty and dotted child names are expected to be rejected.
      with pytest.raises(KeyError):
          cell.insert_child_to_cell('', ModA(ta))
      with pytest.raises(KeyError):
          cell.insert_child_to_cell('a.b', ModB(tb))

      # Wrong value types for either kind of insertion.
      with pytest.raises(TypeError):
          cell.insert_child_to_cell('buffer', tb)
      with pytest.raises(TypeError):
          cell.insert_param_to_cell('w', ta)
      with pytest.raises(TypeError):
          cell.insert_child_to_cell('m', param)

      class ModAddCellError(nn.Cell):
          """Assigns a child cell before the base initializer has run."""

          def __init__(self, tensor):
              self.mod = ModA(tensor)
              super().__init__()

          def construct(self, *inputs):
              return None

      with pytest.raises(AttributeError):
          ModAddCellError(ta)
  def test_train_eval():
      """Check that set_train() toggles the ``training`` flag of a Cell."""
      cell = nn.Cell()
      # A freshly created cell is expected to start out of training mode.
      assert not cell.training
      cell.set_train()
      assert cell.training
      cell.set_train(False)
      assert not cell.training
  def test_stop_update_name():
      """Check the first three parameter names of Net2 (built with auto_prefix=False)."""
      net = Net2(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
      names = list(net.parameters_dict().keys())
      expected = ["weight", "mod1.weight", "mod2.weight"]
      for idx, want in enumerate(expected):
          assert names[idx] == want
  @non_graph_engine
  def test_net_call():
      """Check that building ConvNet and calling construct() raises ValueError."""
      # Construction, input creation and the call all stay inside the context:
      # the test only states that ValueError is raised somewhere in this sequence.
      with pytest.raises(ValueError):
          net = ConvNet()
          image_shape = [1, 3, net.image_h, net.image_w]
          pixels = np.random.randint(0, 255, image_shape).astype(np.float32)
          input_x = Tensor(pixels)
          net.construct(input_x)