You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_cell.py 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test cell """
  16. import copy
  17. import numpy as np
  18. import pytest
  19. import mindspore as ms
  20. import mindspore.nn as nn
  21. from mindspore import Tensor, Parameter
  22. from mindspore.ops import operations as P
  23. from mindspore.common.api import _cell_graph_executor
  24. class ModA(nn.Cell):
  25. def __init__(self, tensor):
  26. super(ModA, self).__init__()
  27. self.weight = Parameter(tensor, name="weight")
  28. def construct(self, *inputs):
  29. pass
  30. class ModB(nn.Cell):
  31. def __init__(self, tensor):
  32. super(ModB, self).__init__()
  33. self.weight = Parameter(tensor, name="weight")
  34. def construct(self, *inputs):
  35. pass
  36. class ModC(nn.Cell):
  37. def __init__(self, ta, tb):
  38. super(ModC, self).__init__()
  39. self.mod1 = ModA(ta)
  40. self.mod2 = ModB(tb)
  41. def construct(self, *inputs):
  42. pass
  43. class Net(nn.Cell):
  44. """ Net definition """
  45. name_len = 4
  46. cells_num = 3
  47. def __init__(self, ta, tb):
  48. super(Net, self).__init__()
  49. self.mod1 = ModA(ta)
  50. self.mod2 = ModB(tb)
  51. self.mod3 = ModC(ta, tb)
  52. def construct(self, *inputs):
  53. pass
  54. class Net2(nn.Cell):
  55. def __init__(self, ta, tb):
  56. super(Net2, self).__init__(auto_prefix=False)
  57. self.mod1 = ModA(ta)
  58. self.mod2 = ModB(tb)
  59. self.mod3 = ModC(ta, tb)
  60. def construct(self, *inputs):
  61. pass
  62. class ConvNet(nn.Cell):
  63. """ ConvNet definition """
  64. image_h = 224
  65. image_w = 224
  66. output_ch = 64
  67. def __init__(self, num_classes=10):
  68. super(ConvNet, self).__init__()
  69. self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode="pad", padding=3)
  70. self.bn1 = nn.BatchNorm2d(ConvNet.output_ch)
  71. self.relu = nn.ReLU()
  72. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
  73. self.flatten = nn.Flatten()
  74. self.fc = nn.Dense(
  75. int(ConvNet.image_h * ConvNet.image_w * ConvNet.output_ch / (4 * 4)),
  76. num_classes)
  77. def construct(self, x):
  78. x = self.conv1(x)
  79. x = self.bn1(x)
  80. x = self.relu(x)
  81. x = self.maxpool(x)
  82. x = self.flatten(x)
  83. x = self.fc(x)
  84. return x
  85. def test_basic():
  86. ta = Tensor(np.ones([2, 3]))
  87. tb = Tensor(np.ones([1, 4]))
  88. n = Net(ta, tb)
  89. names = list(n.parameters_dict().keys())
  90. assert len(names) == n.name_len
  91. assert names[0] == "mod1.weight"
  92. assert names[1] == "mod2.weight"
  93. assert names[2] == "mod3.mod1.weight"
  94. assert names[3] == "mod3.mod2.weight"
  95. def test_parameter_name():
  96. """ test_parameter_name """
  97. ta = Tensor(np.ones([2, 3]))
  98. tb = Tensor(np.ones([1, 4]))
  99. n = Net(ta, tb)
  100. names = []
  101. for m in n.parameters_and_names():
  102. if m[0]:
  103. names.append(m[0])
  104. assert names[0] == "mod1.weight"
  105. assert names[1] == "mod2.weight"
  106. assert names[2] == "mod3.mod1.weight"
  107. assert names[3] == "mod3.mod2.weight"
  108. def test_cell_name():
  109. """ test_cell_name """
  110. ta = Tensor(np.ones([2, 3]))
  111. tb = Tensor(np.ones([1, 4]))
  112. n = Net(ta, tb)
  113. n.insert_child_to_cell('modNone', None)
  114. names = []
  115. for m in n.cells_and_names():
  116. if m[0]:
  117. names.append(m[0])
  118. assert names[0] == "mod1"
  119. assert names[1] == "mod2"
  120. assert names[2] == "mod3"
  121. assert names[3] == "mod3.mod1"
  122. assert names[4] == "mod3.mod2"
  123. def test_cells():
  124. ta = Tensor(np.ones([2, 3]))
  125. tb = Tensor(np.ones([1, 4]))
  126. n = Net(ta, tb)
  127. ch = list(n.cells())
  128. assert len(ch) == n.cells_num
  129. def test_exceptions():
  130. """ test_exceptions """
  131. t = Tensor(np.ones([2, 3]))
  132. class ModError(nn.Cell):
  133. def __init__(self, tensor):
  134. self.weight = Parameter(tensor, name="weight")
  135. super(ModError, self).__init__()
  136. def construct(self, *inputs):
  137. pass
  138. with pytest.raises(AttributeError):
  139. ModError(t)
  140. class ModError1(nn.Cell):
  141. def __init__(self, tensor):
  142. super().__init__()
  143. self.weight = Parameter(tensor, name="weight")
  144. self.weight = None
  145. self.weight = ModA(tensor)
  146. def construct(self, *inputs):
  147. pass
  148. with pytest.raises(TypeError):
  149. ModError1(t)
  150. class ModError2(nn.Cell):
  151. def __init__(self, tensor):
  152. super().__init__()
  153. self.mod = ModA(tensor)
  154. self.mod = None
  155. self.mod = tensor
  156. def construct(self, *inputs):
  157. pass
  158. with pytest.raises(TypeError):
  159. ModError2(t)
  160. m = nn.Cell()
  161. assert m.construct() is None
  162. def test_cell_copy():
  163. net = ConvNet()
  164. copy.deepcopy(net)
  165. def test_del():
  166. """ test_del """
  167. ta = Tensor(np.ones([2, 3]))
  168. tb = Tensor(np.ones([1, 4]))
  169. n = Net(ta, tb)
  170. names = list(n.parameters_dict().keys())
  171. assert len(names) == n.name_len
  172. del n.mod1
  173. names = list(n.parameters_dict().keys())
  174. assert len(names) == n.name_len - 1
  175. with pytest.raises(AttributeError):
  176. del n.mod1.weight
  177. del n.mod2.weight
  178. names = list(n.parameters_dict().keys())
  179. assert len(names) == n.name_len - 2
  180. with pytest.raises(AttributeError):
  181. del n.mod
  182. def test_add_attr():
  183. """ test_add_attr """
  184. ta = Tensor(np.ones([2, 3]))
  185. tb = Tensor(np.ones([1, 4]))
  186. p = Parameter(ta, name="weight")
  187. m = nn.Cell()
  188. m.insert_param_to_cell('weight', p)
  189. with pytest.raises(TypeError):
  190. m.insert_child_to_cell("network", p)
  191. with pytest.raises(KeyError):
  192. m.insert_param_to_cell('', p)
  193. with pytest.raises(KeyError):
  194. m.insert_param_to_cell('a.b', p)
  195. m.insert_param_to_cell('weight', p)
  196. with pytest.raises(KeyError):
  197. m.insert_child_to_cell('', ModA(ta))
  198. with pytest.raises(KeyError):
  199. m.insert_child_to_cell('a.b', ModB(tb))
  200. with pytest.raises(TypeError):
  201. m.insert_child_to_cell('buffer', tb)
  202. with pytest.raises(TypeError):
  203. m.insert_param_to_cell('w', ta)
  204. with pytest.raises(TypeError):
  205. m.insert_child_to_cell('m', p)
  206. class ModAddCellError(nn.Cell):
  207. def __init__(self, tensor):
  208. self.mod = ModA(tensor)
  209. super().__init__()
  210. def construct(self, *inputs):
  211. pass
  212. with pytest.raises(AttributeError):
  213. ModAddCellError(ta)
  214. def test_train_eval():
  215. m = nn.Cell()
  216. assert not m.training
  217. m.set_train()
  218. assert m.training
  219. m.set_train(False)
  220. assert not m.training
  221. def test_stop_update_name():
  222. ta = Tensor(np.ones([2, 3]))
  223. tb = Tensor(np.ones([1, 4]))
  224. n = Net2(ta, tb)
  225. names = list(n.parameters_dict().keys())
  226. assert names[0] == "weight"
  227. assert names[1] == "mod1.weight"
  228. assert names[2] == "mod2.weight"
  229. class ModelName(nn.Cell):
  230. def __init__(self, tensor):
  231. super(ModelName, self).__init__()
  232. self.w2 = Parameter(tensor, name="weight")
  233. self.w1 = Parameter(tensor, name="weight")
  234. self.w3 = Parameter(tensor, name=None)
  235. self.w4 = Parameter(tensor, name=None)
  236. def construct(self, *inputs):
  237. pass
  238. def test_cell_names():
  239. ta = Tensor(np.ones([2, 3]))
  240. mn = ModelName(ta)
  241. with pytest.raises(ValueError):
  242. _cell_graph_executor.compile(mn)
  243. class TestKwargsNet(nn.Cell):
  244. def __init__(self):
  245. super(TestKwargsNet, self).__init__()
  246. def construct(self, p1, p2, p3=False, p4=False):
  247. if p3:
  248. return p1
  249. if p4:
  250. return P.Add()(p1, p2)
  251. return p2
  252. def test_kwargs_default_value1():
  253. """
  254. Feature: Supports Cell kwargs inputs.
  255. Description: Pass kwargs.
  256. Expectation: No exception.
  257. """
  258. x = Tensor([[1], [2], [3]], ms.float32)
  259. y = Tensor([[4], [5], [6]], ms.float32)
  260. net = TestKwargsNet()
  261. res = net(x, y, p4=True)
  262. print(res)
  263. def test_kwargs_default_value2():
  264. """
  265. Feature: Supports Cell kwargs inputs.
  266. Description: Pass kwargs.
  267. Expectation: No exception.
  268. """
  269. # Tensor(np.array([1, 2, 3, 4]), ms.float32).reshape((1, 1, 2, 2))
  270. x = Tensor([[[[1.0, 2.0], [3.0, 4.0]]]], ms.float32)
  271. nn_op = nn.ResizeBilinear()
  272. res = nn_op(x, (4, 4), align_corners=True)
  273. print(res)