You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cell.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test cell """
  16. import copy
  17. import numpy as np
  18. import pytest
  19. import mindspore.nn as nn
  20. from mindspore import Tensor, Parameter
  21. from mindspore.common.api import _executor
  class ModA(nn.Cell):
      """Leaf cell that owns one Parameter called "weight".

      Args:
          tensor: initial value wrapped into the ``weight`` Parameter.
      """

      def __init__(self, tensor):
          # The base Cell must be initialised before a Parameter attribute is
          # assigned; test_exceptions expects the reverse order to fail.
          super().__init__()
          self.weight = Parameter(tensor, name="weight")

      def construct(self, *inputs):
          pass
  class ModB(nn.Cell):
      """Second leaf cell; structurally the same as ModA but a distinct type.

      Args:
          tensor: initial value wrapped into the ``weight`` Parameter.
      """

      def __init__(self, tensor):
          super().__init__()
          # Registered under the attribute name "weight"; the owning cell
          # decides whether a prefix is added to it.
          self.weight = Parameter(tensor, name="weight")

      def construct(self, *inputs):
          pass
  class ModC(nn.Cell):
      """Composite cell nesting a ModA as ``mod1`` and a ModB as ``mod2``.

      Args:
          ta: initial value for mod1's weight.
          tb: initial value for mod2's weight.
      """

      def __init__(self, ta, tb):
          super().__init__()
          # Child order matters: the tests expect mod1 before mod2 when
          # parameter and cell names are enumerated.
          self.mod1 = ModA(ta)
          self.mod2 = ModB(tb)

      def construct(self, *inputs):
          pass
  class Net(nn.Cell):
      """Two leaf cells plus a nested ModC that holds two more leaves.

      With default prefixing the parameter names are mod1.weight,
      mod2.weight, mod3.mod1.weight and mod3.mod2.weight (see test_basic).
      """
      # Expected number of entries in parameters_dict().
      name_len = 4
      # Expected number of direct child cells: mod1, mod2, mod3.
      cells_num = 3

      def __init__(self, ta, tb):
          super().__init__()
          # The tests rely on this insertion order when checking name order.
          self.mod1 = ModA(ta)
          self.mod2 = ModB(tb)
          self.mod3 = ModC(ta, tb)

      def construct(self, *inputs):
          pass
  class Net2(nn.Cell):
      """Same children as Net, but built with ``auto_prefix=False``.

      This cell does not prefix its children's parameter names, so mod1 and
      mod2 both keep the bare name "weight". mod3 (a ModC) still prefixes its
      own children, giving "mod1.weight" and "mod2.weight"
      (see test_stop_update_name).
      """

      def __init__(self, ta, tb):
          super().__init__(auto_prefix=False)
          self.mod1 = ModA(ta)
          self.mod2 = ModB(tb)
          self.mod3 = ModC(ta, tb)

      def construct(self, *inputs):
          pass
  class ConvNet(nn.Cell):
      """Small conv -> batch-norm -> ReLU -> max-pool -> dense classifier.

      The input is presumably (N, 3, image_h, image_w): conv1 takes 3 input
      channels and the Dense size is derived from image_h/image_w. TODO confirm.

      Args:
          num_classes (int): number of outputs of the final Dense layer.
              Default: 10.
      """
      image_h = 224
      image_w = 224
      output_ch = 64

      def __init__(self, num_classes=10):
          super(ConvNet, self).__init__()
          self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode="pad", padding=3)
          self.bn1 = nn.BatchNorm2d(ConvNet.output_ch)
          self.relu = nn.ReLU()
          self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
          self.flatten = nn.Flatten()
          # conv1 (stride 2, kernel 7, padding 3) and maxpool (stride 2, "same")
          # each halve the spatial sides: 224 -> 112 -> 56. Overall that is a
          # factor of 4 per side, which assumes image_h and image_w are
          # multiples of 4. Integer floor division keeps the size an exact int
          # instead of going through float division and int() truncation.
          flat_features = ConvNet.image_h * ConvNet.image_w * ConvNet.output_ch // (4 * 4)
          self.fc = nn.Dense(flat_features, num_classes)

      def construct(self, x):
          x = self.conv1(x)
          x = self.bn1(x)
          x = self.relu(x)
          x = self.maxpool(x)
          x = self.flatten(x)
          x = self.fc(x)
          return x
  def test_basic():
      """parameters_dict() keys are the dotted attribute paths of each weight."""
      first = Tensor(np.ones([2, 3]))
      second = Tensor(np.ones([1, 4]))
      net = Net(first, second)
      keys = list(net.parameters_dict())
      assert len(keys) == net.name_len
      # With the length already pinned, whole-list equality is the same as
      # checking each position individually.
      assert keys == ["mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight"]
  def test_parameter_name():
      """parameters_and_names() yields the same dotted names as parameters_dict()."""
      ta = Tensor(np.ones([2, 3]))
      tb = Tensor(np.ones([1, 4]))
      n = Net(ta, tb)
      # Keep only the name of each yielded item, skipping empty names.
      names = [item[0] for item in n.parameters_and_names() if item[0]]
      # Pin the count as test_basic does, so unexpected extra parameters
      # cannot slip past the positional comparison below.
      assert len(names) == n.name_len
      assert names == ["mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight"]
  def test_cell_name():
      """cells_and_names() lists nested cells depth-first and skips None children."""
      ta = Tensor(np.ones([2, 3]))
      tb = Tensor(np.ones([1, 4]))
      n = Net(ta, tb)
      # Inserting a None child must be accepted...
      n.insert_child_to_cell('modNone', None)
      # Entries with an empty name (presumably the root cell itself) are dropped.
      names = [item[0] for item in n.cells_and_names() if item[0]]
      # ...and must not show up when iterating. Comparing the whole list
      # catches a stray 'modNone', which would otherwise sit past index 4 where
      # no check looks.
      assert names == ["mod1", "mod2", "mod3", "mod3.mod1", "mod3.mod2"]
      assert "modNone" not in names
  def test_cells():
      """cells() yields only the direct children (mod1, mod2, mod3)."""
      net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
      children = list(net.cells())
      assert len(children) == net.cells_num
  def test_exceptions():
      """Misusing Cell attributes raises the expected exception types."""
      data = Tensor(np.ones([2, 3]))

      # Assigning a Parameter before Cell.__init__ has run -> AttributeError.
      class ModError(nn.Cell):
          def __init__(self, tensor):
              self.weight = Parameter(tensor, name="weight")
              super().__init__()

          def construct(self, *inputs):
              pass

      with pytest.raises(AttributeError):
          ModError(data)

      # Rebinding one attribute Parameter -> None -> Cell is expected to
      # raise TypeError during construction.
      class ModError1(nn.Cell):
          def __init__(self, tensor):
              super().__init__()
              self.weight = Parameter(tensor, name="weight")
              self.weight = None
              self.weight = ModA(tensor)

          def construct(self, *inputs):
              pass

      with pytest.raises(TypeError):
          ModError1(data)

      # Likewise, rebinding Cell -> None -> Tensor is expected to raise TypeError.
      class ModError2(nn.Cell):
          def __init__(self, tensor):
              super().__init__()
              self.mod = ModA(tensor)
              self.mod = None
              self.mod = tensor

          def construct(self, *inputs):
              pass

      with pytest.raises(TypeError):
          ModError2(data)

      # The base Cell provides no forward pass of its own.
      bare = nn.Cell()
      with pytest.raises(NotImplementedError):
          bare.construct()
  def test_cell_copy():
      """Deep-copying a cell with real layers yields an independent, equivalent cell."""
      net = ConvNet()
      copied = copy.deepcopy(net)
      # Check the copy itself, not just that deepcopy does not raise.
      assert isinstance(copied, ConvNet)
      assert copied is not net
      assert list(copied.parameters_dict()) == list(net.parameters_dict())
  def test_del():
      """Deleting child cells or their parameters shrinks parameters_dict()."""
      ta = Tensor(np.ones([2, 3]))
      tb = Tensor(np.ones([1, 4]))
      n = Net(ta, tb)

      def param_count():
          # Number of entries currently reported by parameters_dict().
          return len(n.parameters_dict())

      assert param_count() == n.name_len
      # Removing a child cell removes its parameter as well.
      del n.mod1
      assert param_count() == n.name_len - 1
      # mod1 no longer exists, so reaching through it fails.
      with pytest.raises(AttributeError):
          del n.mod1.weight
      # Deleting a Parameter attribute of a remaining child removes it too.
      del n.mod2.weight
      assert param_count() == n.name_len - 2
      # Deleting an attribute that was never set is an AttributeError.
      with pytest.raises(AttributeError):
          del n.mod
  def test_add_attr():
      """insert_param_to_cell / insert_child_to_cell validate names and value types."""
      small = Tensor(np.ones([2, 3]))
      row = Tensor(np.ones([1, 4]))
      param = Parameter(small, name="weight")
      cell = nn.Cell()
      cell.insert_param_to_cell('weight', param)

      # A Parameter is not a valid child cell.
      with pytest.raises(TypeError):
          cell.insert_child_to_cell("network", param)

      # Empty and dotted names are rejected for parameters...
      with pytest.raises(KeyError):
          cell.insert_param_to_cell('', param)
      with pytest.raises(KeyError):
          cell.insert_param_to_cell('a.b', param)
      # ...while re-inserting under an already-used valid name is accepted.
      cell.insert_param_to_cell('weight', param)
      # The same naming rules apply to child cells.
      with pytest.raises(KeyError):
          cell.insert_child_to_cell('', ModA(small))
      with pytest.raises(KeyError):
          cell.insert_child_to_cell('a.b', ModB(row))

      # Wrong value types: a plain Tensor as a child, a plain Tensor as a
      # parameter, and a Parameter as a child.
      with pytest.raises(TypeError):
          cell.insert_child_to_cell('buffer', row)
      with pytest.raises(TypeError):
          cell.insert_param_to_cell('w', small)
      with pytest.raises(TypeError):
          cell.insert_child_to_cell('m', param)

      # Assigning a child Cell before Cell.__init__ has run -> AttributeError.
      class ModAddCellError(nn.Cell):
          def __init__(self, tensor):
              self.mod = ModA(tensor)
              super().__init__()

          def construct(self, *inputs):
              pass

      with pytest.raises(AttributeError):
          ModAddCellError(small)
  def test_train_eval():
      """A new cell is not training; set_train() toggles the flag."""
      cell = nn.Cell()
      assert not cell.training
      # Each step: positional args for set_train(), expected truthiness after it.
      for args, expect_training in (((), True), ((False,), False)):
          cell.set_train(*args)
          assert bool(cell.training) is expect_training
  def test_stop_update_name():
      """auto_prefix=False leaves the direct children's parameter names unprefixed."""
      net = Net2(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
      keys = list(net.parameters_dict())
      # mod1 and mod2 both keep the bare name "weight". Since index 1 is
      # expected to be mod3's first parameter, the two presumably collapse
      # into a single dict key. mod3 (a ModC) still prefixes its own children.
      for position, expected in enumerate(["weight", "mod1.weight", "mod2.weight"]):
          assert keys[position] == expected
  class ModelName(nn.Cell):
      """Cell whose parameters have clashing or missing names.

      w2 and w1 are both explicitly named "weight"; w3 and w4 are created with
      name=None. test_cell_names expects compiling this cell to raise ValueError.

      Args:
          tensor: initial value shared by all four Parameters.
      """

      def __init__(self, tensor):
          super().__init__()
          # Registration order is deliberate: w2 is assigned before w1.
          self.w2 = Parameter(tensor, name="weight")
          self.w1 = Parameter(tensor, name="weight")
          self.w3 = Parameter(tensor, name=None)
          self.w4 = Parameter(tensor, name=None)

      def construct(self, *inputs):
          pass
  def test_cell_names():
      """Compiling ModelName (duplicate / missing parameter names) raises ValueError."""
      model = ModelName(Tensor(np.ones([2, 3])))
      with pytest.raises(ValueError):
          _executor.compile(model)