You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_stridedslice.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore as ms
  18. from mindspore import context, Tensor, Parameter
  19. from mindspore.common.api import _cell_graph_executor
  20. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  21. from mindspore.ops import operations as P
  22. from parallel.utils.utils import ParallelValidator
  23. class Net(Cell):
  24. def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True,
  25. begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
  26. super().__init__()
  27. self.mul = P.Mul().shard(strategy1)
  28. self.strided_slice = P.StridedSlice(begin_mask=begin_mask,
  29. end_mask=end_mask,
  30. ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
  31. shrink_axis_mask=shrink_axis_mask).shard(strategy2)
  32. if is_parameter:
  33. self.weight = Parameter(weight, "w1")
  34. else:
  35. self.weight = weight
  36. self.mul2 = P.Mul()
  37. self.weight2 = Parameter(w2, "w2")
  38. self.begin = begin
  39. self.end = end
  40. self.strides = strides
  41. def construct(self, x, b):
  42. out = self.strided_slice(self.weight, self.begin, self.end, self.strides)
  43. out = self.mul(x, out)
  44. out = self.mul2(out, self.weight2)
  45. return out
  46. class Net2(Cell):
  47. def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None,
  48. begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
  49. super().__init__()
  50. self.mul = P.Mul().shard(strategy1)
  51. self.strided_slice = P.StridedSlice(begin_mask=begin_mask,
  52. end_mask=end_mask,
  53. ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
  54. shrink_axis_mask=shrink_axis_mask).shard(strategy2)
  55. self.weight2 = Parameter(weight2, "w2")
  56. self.begin = begin
  57. self.end = end
  58. self.strides = strides
  59. def construct(self, x, b):
  60. out = self.mul(x, self.weight2)
  61. out = self.strided_slice(out, self.begin, self.end, self.strides)
  62. return out
  63. _x1 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
  64. _x2 = Tensor(np.ones([1, 64, 32, 32]), dtype=ms.float32)
  65. _x3 = Tensor(np.ones([64, 32]), dtype=ms.float32)
  66. _w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32)
  67. _w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
  68. _w3 = Tensor(np.ones([1, 64, 32, 32]), dtype=ms.float32)
  69. _b1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
  70. _b2 = Tensor(np.ones([1, 64, 32, 32]), dtype=ms.float32)
  71. def compile_net(net, _x1, _b1):
  72. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  73. train_net = TrainOneStepCell(net, optimizer)
  74. train_net.set_auto_parallel()
  75. train_net.set_train()
  76. _cell_graph_executor.compile(train_net, _x1, _b1)
  77. context.reset_auto_parallel_context()
  78. def compile_net_utils(net: Cell, *inputs):
  79. net.set_auto_parallel()
  80. net.set_train()
  81. phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=True)
  82. context.reset_auto_parallel_context()
  83. return phase
  84. def test_stridedslice_no_fully_fetch_split_error():
  85. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  86. strategy1 = ((2, 2, 2), (2, 2, 2))
  87. strategy2 = ((2, 2, 2),)
  88. net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
  89. with pytest.raises(RuntimeError):
  90. compile_net(net, _x1, _b1)
  91. def test_stridedslice_strides_no_1_split_error():
  92. """
  93. Feature: distribute operator stridedslice in auto parallel mode.
  94. Description: test stridedslice with strides no 1 split in semi auto parallel.
  95. Expectation: compile error.
  96. """
  97. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  98. strategy1 = ((2, 2, 2), (2, 2, 2))
  99. strategy2 = ((1, 2, 2),)
  100. net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 2), strategy1, strategy2, is_parameter=True)
  101. with pytest.raises(RuntimeError):
  102. compile_net(net, _x1, _b1)
  103. def test_stridedslice_begin_size_smaller():
  104. """
  105. Feature: distribute operator stridedslice in auto parallel mode.
  106. Description: test stridedslice with begin size is smaller in semi auto parallel.
  107. Expectation: compile done without error.
  108. """
  109. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  110. strategy1 = ((1, 4, 1), (1, 4, 2))
  111. strategy2 = ((1, 4, 2),)
  112. net = Net(_w1, _w2, (0, 0), (128, 64), (1, 1), strategy1, strategy2, is_parameter=True)
  113. compile_net(net, _x1, _b1)
  114. def test_stridedslice_parameter():
  115. """
  116. Feature: distribute operator stridedslice in auto parallel mode.
  117. Description: test stridedslice of parameter in semi auto parallel.
  118. Expectation: compile done without error.
  119. """
  120. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  121. strategy1 = ((1, 4, 1), (1, 4, 2))
  122. strategy2 = ((1, 4, 2),)
  123. net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
  124. compile_net(net, _x1, _b1)
  125. def test_stridedslice_begin_mask_no_0_split_parameter():
  126. """
  127. Feature: distribute operator stridedslice in auto parallel mode.
  128. Description: test stridedslice with begin mask no 0 split in semi auto parallel.
  129. Expectation: compile done without error.
  130. """
  131. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  132. strategy1 = ((1, 4, 1), (1, 4, 2))
  133. strategy2 = ((1, 4, 2),)
  134. net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True, begin_mask=1)
  135. compile_net(net, _x1, _b1)
  136. def test_stridedslice_end_mask_no_0_parameter():
  137. """
  138. Feature: distribute operator stridedslice in auto parallel mode.
  139. Description: test stridedslice with end mask no 0 in semi auto parallel.
  140. Expectation: compile done without error.
  141. """
  142. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  143. strategy1 = ((1, 4, 1), (1, 4, 2))
  144. strategy2 = ((1, 4, 2),)
  145. net = Net(_w1, _w2, (127, 0, 0), (128, 63, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
  146. begin_mask=1, end_mask=2)
  147. compile_net(net, _x1, _b1)
  148. def test_stridedslice_ellipsis_mask_no_0_parameter():
  149. """
  150. Feature: distribute operator stridedslice in auto parallel mode.
  151. Description: test stridedslice with ellipsis mask no 0 in semi auto parallel.
  152. Expectation: compile done without error.
  153. """
  154. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  155. strategy1 = ((1, 4, 1), (1, 4, 2))
  156. strategy2 = ((1, 4, 2),)
  157. net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
  158. begin_mask=1, end_mask=2, ellipsis_mask=4)
  159. compile_net(net, _x1, _b1)
  160. def test_stridedslice_new_axis_mask_no_0_parameter():
  161. """
  162. Feature: distribute operator stridedslice in auto parallel mode.
  163. Description: test stridedslice with new axis mask no 0 in semi auto parallel.
  164. Expectation: compile done without error.
  165. """
  166. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  167. strategy1 = ((1, 4, 2, 1), (1, 4, 2, 1))
  168. strategy2 = ((1, 1, 4),)
  169. net = Net(_w1, _w3, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
  170. new_axis_mask=1)
  171. compile_net(net, _x2, _b2)
  172. def test_stridedslice_shrink_axis_mask_no_0_parameter():
  173. """
  174. Feature: distribute operator stridedslice in auto parallel mode.
  175. Description: test stridedslice with shrink axis mask no 0 in semi auto parallel.
  176. Expectation: compile done without error.
  177. """
  178. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  179. strategy1 = ((1, 2), (1, 2))
  180. strategy2 = ((1, 4, 1),)
  181. net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
  182. shrink_axis_mask=1)
  183. compile_net(net, _x3, _b1)
  184. def test_stridedslice_tensor():
  185. """
  186. Feature: distribute operator stridedslice in auto parallel mode.
  187. Description: test stridedslice of tensor in semi auto parallel.
  188. Expectation: compile done without error.
  189. """
  190. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  191. strategy1 = ((1, 4, 1), (1, 4, 2))
  192. strategy2 = ((1, 4, 2),)
  193. net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False)
  194. compile_net(net, _x1, _b1)
  195. def test_stridedslice_begin_mask_no_0_tensor():
  196. """
  197. Feature: distribute operator stridedslice in auto parallel mode.
  198. Description: test stridedslice with begin mask no 0 in semi auto parallel.
  199. Expectation: compile done without error.
  200. """
  201. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  202. strategy1 = ((1, 4, 1), (1, 4, 2))
  203. strategy2 = ((1, 4, 2),)
  204. net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False, begin_mask=1)
  205. compile_net(net, _x1, _b1)
  206. def test_stridedslice_end_mask_no_0_tensor():
  207. """
  208. Feature: distribute operator stridedslice in auto parallel mode.
  209. Description: test stridedslice with end mask no 0 in semi auto parallel.
  210. Expectation: compile done without error.
  211. """
  212. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  213. strategy1 = ((1, 4, 1), (1, 4, 2))
  214. strategy2 = ((1, 4, 2),)
  215. net = Net(_w1, _w2, (0, 0, 0), (128, 63, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False, end_mask=2)
  216. compile_net(net, _x1, _b1)
  217. def test_stridedslice_ellipsis_mask_no_0_tensor():
  218. """
  219. Feature: distribute operator stridedslice in auto parallel mode.
  220. Description: test stridedslice with ellipsis mask no 0 in semi auto parallel.
  221. Expectation: compile done without error.
  222. """
  223. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  224. strategy1 = ((1, 4, 1), (1, 4, 2))
  225. strategy2 = ((1, 4, 2),)
  226. net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False,
  227. begin_mask=1, end_mask=2, ellipsis_mask=4)
  228. compile_net(net, _x1, _b1)
  229. def test_stridedslice_new_axis_mask_no_0_tensor():
  230. """
  231. Feature: distribute operator stridedslice in auto parallel mode.
  232. Description: test stridedslice with new axis mask no 0 in semi auto parallel.
  233. Expectation: compile done without error.
  234. """
  235. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  236. strategy1 = ((1, 4, 2, 1), (1, 4, 2, 1))
  237. strategy2 = ((1, 1, 4),)
  238. net = Net(_w1, _w3, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False,
  239. new_axis_mask=1)
  240. compile_net(net, _x2, _b2)
  241. def test_stridedslice_shrink_axis_mask_no_0_tensor():
  242. """
  243. Feature: distribute operator stridedslice in auto parallel mode.
  244. Description: test stridedslice with shrink axis mask no 0 in semi auto parallel.
  245. Expectation: compile done without error.
  246. """
  247. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  248. strategy1 = ((1, 2), (1, 2))
  249. strategy2 = ((1, 4, 1),)
  250. net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False,
  251. shrink_axis_mask=1)
  252. compile_net(net, _x3, _b1)
  253. def test_stridedslice_parameter_no_full_split():
  254. """
  255. Feature: distribute operator stridedslice in auto parallel mode.
  256. Description: test stridedslice with no full split in semi auto parallel.
  257. Expectation: compile done without error.
  258. """
  259. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  260. strategy1 = ((1, 4, 1), (1, 4, 2))
  261. strategy2 = ((1, 2, 2),)
  262. net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
  263. compile_net(net, _x1, _b1)
  264. def test_stridedslice_output():
  265. """
  266. Feature: distribute operator stridedslice in auto parallel mode.
  267. Description: test stridedslice of output in semi auto parallel.
  268. Expectation: compile done without error.
  269. """
  270. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  271. strategy1 = ((1, 8, 1), (1, 8, 1))
  272. strategy2 = ((1, 8, 1),)
  273. net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
  274. compile_net(net, _x1, _b1)
  275. def test_stridedslice_begin_mask_no_0_output():
  276. """
  277. Feature: distribute operator stridedslice in auto parallel mode.
  278. Description: test stridedslice with begin mask no 0 in semi auto parallel.
  279. Expectation: compile done without error.
  280. """
  281. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  282. strategy1 = ((1, 8, 1), (1, 8, 1))
  283. strategy2 = ((1, 8, 1),)
  284. net = Net2(_w2, (61, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2, begin_mask=1)
  285. compile_net(net, _x1, _b1)
  286. def test_stridedslice_end_mask_no_0_output():
  287. """
  288. Feature: distribute operator stridedslice in auto parallel mode.
  289. Description: test stridedslice with end mask no 0 in semi auto parallel.
  290. Expectation: compile done without error.
  291. """
  292. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  293. strategy1 = ((1, 8, 1), (1, 8, 1))
  294. strategy2 = ((1, 8, 1),)
  295. net = Net2(_w2, (0, 0, 0), (64, 63, 1), (1, 1, 1), strategy1, strategy2, end_mask=2)
  296. compile_net(net, _x1, _b1)
  297. def test_stridedslice_ellipsis_mask_no_0_output():
  298. """
  299. Feature: distribute operator stridedslice in auto parallel mode.
  300. Description: test stridedslice with ellipsis mask no 0 in semi auto parallel.
  301. Expectation: compile done without error.
  302. """
  303. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  304. strategy1 = ((1, 8, 1), (1, 8, 1))
  305. strategy2 = ((1, 8, 1),)
  306. net = Net2(_w2, (63, 0, 0), (64, 63, 1), (1, 1, 1), strategy1, strategy2,
  307. begin_mask=1, end_mask=2, ellipsis_mask=4)
  308. compile_net(net, _x1, _b1)
  309. def test_stridedslice_new_axis_mask_no_0_output():
  310. """
  311. Feature: distribute operator stridedslice in auto parallel mode.
  312. Description: test stridedslice with new axis mask no 0 in semi auto parallel.
  313. Expectation: compile done without error.
  314. """
  315. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  316. strategy1 = ((1, 8, 1), (1, 8, 1))
  317. strategy2 = ((8, 1, 1),)
  318. net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2, new_axis_mask=1)
  319. compile_net(net, _x1, _b1)
  320. def test_stridedslice_shrink_axis_mask_no_0_output():
  321. """
  322. Feature: distribute operator stridedslice in auto parallel mode.
  323. Description: test stridedslice with shrink axis mask no 0 in semi auto parallel.
  324. Expectation: compile done without error.
  325. """
  326. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  327. strategy1 = ((1, 8, 1), (1, 8, 1))
  328. strategy2 = ((1, 8, 1),)
  329. net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2, shrink_axis_mask=1)
  330. compile_net(net, _x1, _b1)
  331. def test_stridedslice_output_no_full_split():
  332. """
  333. Feature: distribute operator stridedslice in auto parallel mode.
  334. Description: test stridedslice with no full split in semi auto parallel.
  335. Expectation: compile done without error.
  336. """
  337. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  338. strategy1 = ((1, 8, 1), (1, 8, 1))
  339. strategy2 = ((1, 4, 1),)
  340. net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
  341. compile_net(net, _x1, _b1)
  342. def test_stridedslice_no_strategy():
  343. """
  344. Feature: distribute operator stridedslice in auto parallel mode.
  345. Description: test stridedslice with no strategy in semi auto parallel.
  346. Expectation: compile done without error.
  347. """
  348. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  349. strategy1 = ((1, 8, 1), (1, 8, 1))
  350. strategy2 = None
  351. net = Net2(_w2, (0, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2)
  352. compile_net(net, _x1, _b1)
  353. def test_stridedslice_begin_mask_no_0_no_strategy():
  354. """
  355. Feature: distribute operator stridedslice in auto parallel mode.
  356. Description: test stridedslice with begin mask no 0 in auto parallel.
  357. Expectation: compile done without error.
  358. """
  359. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  360. strategy1 = ((1, 8, 1), (1, 8, 1))
  361. strategy2 = None
  362. net = Net2(_w2, (127, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2, begin_mask=1)
  363. compile_net(net, _x1, _b1)
  364. def test_stridedslice_auto_parallel():
  365. """
  366. Feature: distribute operator stridedslice in auto parallel mode.
  367. Description: test stridedslice in auto parallel.
  368. Expectation: compile done without error.
  369. """
  370. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  371. net = Net2(_w2, (0, 0, 0), (32, 64, 1), (1, 1, 1))
  372. compile_net(net, _x1, _b1)
  373. def test_stridedslice_begin_mask_no_0_auto_parallel():
  374. """
  375. Feature: distribute operator stridedslice in auto parallel mode.
  376. Description: test stridedslice with begin mask no 0 in auto parallel.
  377. Expectation: compile done without error.
  378. """
  379. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  380. net = Net2(_w2, (29, 0, 0), (32, 64, 1), (1, 1, 1), begin_mask=1)
  381. compile_net(net, _x1, _b1)
  382. def test_stridedslice_layout():
  383. """
  384. Features: StridedSlice
  385. Description: validate layout and structure
  386. Expectation: No raise RuntimeError
  387. """
  388. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  389. strategy1 = ((1, 4, 1), (1, 4, 2))
  390. strategy2 = ((1, 4, 2),)
  391. net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
  392. begin_mask=1, end_mask=2, ellipsis_mask=4)
  393. phase = compile_net_utils(net, _x1, _b1)
  394. validator = ParallelValidator(net, phase)
  395. # check layout
  396. features_expect_layout = ([4, 2], [-1, 1, 0], [256, 16, 16], 0, True, '')
  397. assert validator.check_parameter_layout('w1', features_expect_layout)
  398. # check attrs
  399. roi_expect_attrs = {'begin_mask': 1, 'end_mask': 2, 'ellipsis_mask': 4}
  400. assert validator.check_node_attrs('StridedSlice-1', roi_expect_attrs)
  401. # check inputs
  402. roi_expect_inputs = ['Load-0', 'out((127, 0, 0))', 'out((128, 64, 32))', 'out((1, 1, 1))']
  403. assert validator.check_node_inputs('StridedSlice-1', roi_expect_inputs)
  404. # check sub_graph
  405. sub_graph = {
  406. 'StridedSlice-1': ['Load-0', 'out((127, 0, 0))', 'out((128, 64, 32))', 'out((1, 1, 1))'],
  407. 'Mul-0': ['Reshape-1', 'StridedSlice-1'],
  408. 'AllGather-2': ['Reshape-2'],
  409. 'Split-1': ['AllGather-2'],
  410. 'TupleGetItem-3': ['Split-1', 0],
  411. 'TupleGetItem-4': ['Split-1', 1],
  412. 'TupleGetItem-5': ['Split-1', 2],
  413. 'TupleGetItem-6': ['Split-1', 3],
  414. 'MakeTuple-2': ['TupleGetItem-3', 'TupleGetItem-4', 'TupleGetItem-5', 'TupleGetItem-6'],
  415. 'Concat-1': ['MakeTuple-2']
  416. }
  417. assert validator.check_graph_structure(sub_graph)