You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cell.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test cell """
  16. import numpy as np
  17. import pytest
  18. import mindspore.context as context
  19. import mindspore.nn as nn
  20. from mindspore import Tensor, Parameter
  21. from mindspore.common.api import _executor
  22. from ..ut_filter import non_graph_engine
  23. class ModA(nn.Cell):
  24. def __init__(self, tensor):
  25. super(ModA, self).__init__()
  26. self.weight = Parameter(tensor, name="weight")
  27. def construct(self, *inputs):
  28. pass
  29. class ModB(nn.Cell):
  30. def __init__(self, tensor):
  31. super(ModB, self).__init__()
  32. self.weight = Parameter(tensor, name="weight")
  33. def construct(self, *inputs):
  34. pass
  35. class ModC(nn.Cell):
  36. def __init__(self, ta, tb):
  37. super(ModC, self).__init__()
  38. self.mod1 = ModA(ta)
  39. self.mod2 = ModB(tb)
  40. def construct(self, *inputs):
  41. pass
  42. class Net(nn.Cell):
  43. """ Net definition """
  44. name_len = 4
  45. cells_num = 3
  46. def __init__(self, ta, tb):
  47. super(Net, self).__init__()
  48. self.mod1 = ModA(ta)
  49. self.mod2 = ModB(tb)
  50. self.mod3 = ModC(ta, tb)
  51. def construct(self, *inputs):
  52. pass
  53. class Net2(nn.Cell):
  54. def __init__(self, ta, tb):
  55. super(Net2, self).__init__(auto_prefix=False)
  56. self.mod1 = ModA(ta)
  57. self.mod2 = ModB(tb)
  58. self.mod3 = ModC(ta, tb)
  59. def construct(self, *inputs):
  60. pass
  61. class ConvNet(nn.Cell):
  62. """ ConvNet definition """
  63. image_h = 224
  64. image_w = 224
  65. output_ch = 64
  66. def __init__(self, num_classes=10):
  67. super(ConvNet, self).__init__()
  68. self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode="pad", padding=3)
  69. self.bn1 = nn.BatchNorm2d(ConvNet.output_ch)
  70. self.relu = nn.ReLU()
  71. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
  72. self.flatten = nn.Flatten()
  73. self.fc = nn.Dense(
  74. int(ConvNet.image_h*ConvNet.image_w*ConvNet.output_ch/(4*4)),
  75. num_classes)
  76. def construct(self, x):
  77. x = self.conv1(x)
  78. x = self.bn1(x)
  79. x = self.relu(x)
  80. x = self.maxpool(x)
  81. x = self.flatten(x)
  82. x = self.fc(x)
  83. return x
  84. def test_basic():
  85. ta = Tensor(np.ones([2, 3]))
  86. tb = Tensor(np.ones([1, 4]))
  87. n = Net(ta, tb)
  88. names = list(n.parameters_dict().keys())
  89. assert len(names) == n.name_len
  90. assert names[0] == "mod1.weight"
  91. assert names[1] == "mod2.weight"
  92. assert names[2] == "mod3.mod1.weight"
  93. assert names[3] == "mod3.mod2.weight"
  94. def test_parameter_name():
  95. """ test_parameter_name """
  96. ta = Tensor(np.ones([2, 3]))
  97. tb = Tensor(np.ones([1, 4]))
  98. n = Net(ta, tb)
  99. names = []
  100. for m in n.parameters_and_names():
  101. if m[0]:
  102. names.append(m[0])
  103. assert names[0] == "mod1.weight"
  104. assert names[1] == "mod2.weight"
  105. assert names[2] == "mod3.mod1.weight"
  106. assert names[3] == "mod3.mod2.weight"
  107. def test_cell_name():
  108. """ test_cell_name """
  109. ta = Tensor(np.ones([2, 3]))
  110. tb = Tensor(np.ones([1, 4]))
  111. n = Net(ta, tb)
  112. n.insert_child_to_cell('modNone', None)
  113. names = []
  114. for m in n.cells_and_names():
  115. if m[0]:
  116. names.append(m[0])
  117. assert names[0] == "mod1"
  118. assert names[1] == "mod2"
  119. assert names[2] == "mod3"
  120. assert names[3] == "mod3.mod1"
  121. assert names[4] == "mod3.mod2"
  122. def test_cells():
  123. ta = Tensor(np.ones([2, 3]))
  124. tb = Tensor(np.ones([1, 4]))
  125. n = Net(ta, tb)
  126. ch = list(n.cells())
  127. assert len(ch) == n.cells_num
  128. def test_exceptions():
  129. """ test_exceptions """
  130. t = Tensor(np.ones([2, 3]))
  131. class ModError(nn.Cell):
  132. def __init__(self, tensor):
  133. self.weight = Parameter(tensor, name="weight")
  134. super(ModError, self).__init__()
  135. def construct(self, *inputs):
  136. pass
  137. with pytest.raises(AttributeError):
  138. ModError(t)
  139. class ModError1(nn.Cell):
  140. def __init__(self, tensor):
  141. super().__init__()
  142. self.weight = Parameter(tensor, name="weight")
  143. self.weight = None
  144. self.weight = ModA(tensor)
  145. def construct(self, *inputs):
  146. pass
  147. with pytest.raises(TypeError):
  148. ModError1(t)
  149. class ModError2(nn.Cell):
  150. def __init__(self, tensor):
  151. super().__init__()
  152. self.mod = ModA(tensor)
  153. self.mod = None
  154. self.mod = tensor
  155. def construct(self, *inputs):
  156. pass
  157. with pytest.raises(TypeError):
  158. ModError2(t)
  159. m = nn.Cell()
  160. with pytest.raises(NotImplementedError):
  161. m.construct()
  162. def test_del():
  163. """ test_del """
  164. ta = Tensor(np.ones([2, 3]))
  165. tb = Tensor(np.ones([1, 4]))
  166. n = Net(ta, tb)
  167. names = list(n.parameters_dict().keys())
  168. assert len(names) == n.name_len
  169. del n.mod1
  170. names = list(n.parameters_dict().keys())
  171. assert len(names) == n.name_len - 1
  172. with pytest.raises(AttributeError):
  173. del n.mod1.weight
  174. del n.mod2.weight
  175. names = list(n.parameters_dict().keys())
  176. assert len(names) == n.name_len - 2
  177. with pytest.raises(AttributeError):
  178. del n.mod
  179. def test_add_attr():
  180. """ test_add_attr """
  181. ta = Tensor(np.ones([2, 3]))
  182. tb = Tensor(np.ones([1, 4]))
  183. p = Parameter(ta, name="weight")
  184. m = nn.Cell()
  185. m.insert_param_to_cell('weight', p)
  186. with pytest.raises(TypeError):
  187. m.insert_child_to_cell("network", p)
  188. with pytest.raises(KeyError):
  189. m.insert_param_to_cell('', p)
  190. with pytest.raises(KeyError):
  191. m.insert_param_to_cell('a.b', p)
  192. m.insert_param_to_cell('weight', p)
  193. with pytest.raises(KeyError):
  194. m.insert_child_to_cell('', ModA(ta))
  195. with pytest.raises(KeyError):
  196. m.insert_child_to_cell('a.b', ModB(tb))
  197. with pytest.raises(TypeError):
  198. m.insert_child_to_cell('buffer', tb)
  199. with pytest.raises(TypeError):
  200. m.insert_param_to_cell('w', ta)
  201. with pytest.raises(TypeError):
  202. m.insert_child_to_cell('m', p)
  203. class ModAddCellError(nn.Cell):
  204. def __init__(self, tensor):
  205. self.mod = ModA(tensor)
  206. super().__init__()
  207. def construct(self, *inputs):
  208. pass
  209. with pytest.raises(AttributeError):
  210. ModAddCellError(ta)
  211. def test_train_eval():
  212. m = nn.Cell()
  213. assert not m.training
  214. m.set_train()
  215. assert m.training
  216. m.set_train(False)
  217. assert not m.training
  218. def test_stop_update_name():
  219. ta = Tensor(np.ones([2, 3]))
  220. tb = Tensor(np.ones([1, 4]))
  221. n = Net2(ta, tb)
  222. names = list(n.parameters_dict().keys())
  223. assert names[0] == "weight"
  224. assert names[1] == "mod1.weight"
  225. assert names[2] == "mod2.weight"
  226. class ModelName(nn.Cell):
  227. def __init__(self, tensor):
  228. super(ModelName, self).__init__()
  229. self.w2 = Parameter(tensor, name="weight")
  230. self.w1 = Parameter(tensor, name="weight")
  231. self.w3 = Parameter(tensor, name=None)
  232. self.w4 = Parameter(tensor, name=None)
  233. def construct(self, *inputs):
  234. pass
  235. def test_cell_names():
  236. ta = Tensor(np.ones([2, 3]))
  237. mn = ModelName(ta)
  238. with pytest.raises(ValueError):
  239. _executor.compile(mn)