You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_indexing_op.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import numpy as np
  11. import pytest
  12. import megengine.core.ops.builtin
  13. import megengine.core.tensor.raw_tensor
  14. from megengine.core.ops._internal import all_ops
  15. from megengine.core.tensor import Tensor
  16. from megengine.core.tensor.core import apply
  17. from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor
  18. def cvt_to_shape_desc(val, inpvar, config=None):
  19. def as_tensor(val, device):
  20. assert device is not None, "can not infer device"
  21. # TODO: should copy to appropriate device
  22. val = as_raw_tensor(val, device=device)
  23. return val
  24. device = None
  25. if inpvar is not None:
  26. assert isinstance(inpvar, RawTensor)
  27. device = device or inpvar.device
  28. if config is not None:
  29. device = device or config.device
  30. if isinstance(val, RawTensor):
  31. return as_tensor(val, device)
  32. if not isinstance(val, collections.Iterable):
  33. val = [val]
  34. components = []
  35. on_host = True
  36. for i in val:
  37. if isinstance(i, RawTensor):
  38. on_host = False
  39. device = device or i.device
  40. else:
  41. assert isinstance(i, int), (
  42. "shape desc could contain either int or Tensor, got {}"
  43. " actually".format(repr(i))
  44. )
  45. components.append(i)
  46. assert components, "shape desc could not be empty"
  47. if on_host:
  48. shape = np.ascontiguousarray(components, dtype=np.int32)
  49. assert np.all(shape == components), "failed to convert to shape: {}".format(
  50. components
  51. )
  52. return as_tensor(shape, device)
  53. for idx, v in enumerate(components):
  54. if not isinstance(v, RawTensor):
  55. vi = int(v)
  56. assert vi == v, "could not convert {} to int".format(v)
  57. v = vi
  58. components[idx] = as_tensor(v, device)
  59. return invoke_op(all_oprs.Concat(axis=0), components)
  60. def canonize_reshape(inputs, *, config):
  61. src, tshape = inputs
  62. tshape = cvt_to_shape_desc(tshape, src, config)
  63. return src, tshape
  64. def canonize_inputs(inputs, *, config):
  65. """convert immediate numbers and SharedND to SymbolVar in inputs; at least
  66. one of the inputs must be SymbolVar, so comp node and comp graph can
  67. beinferred
  68. :return: list of converted vars
  69. """
  70. if (
  71. isinstance(inputs, (list, tuple))
  72. and len(inputs) == 1
  73. and isinstance(inputs[0], (list, tuple))
  74. ):
  75. # handle the case when a list is passed to a function with
  76. # variable-length argument (e.g. concat has signature concat(*inputs)
  77. # and is called with concat([a, b]))
  78. inputs = inputs[0]
  79. if isinstance(inputs, RawTensor):
  80. return [inputs]
  81. old_inputs = inputs
  82. inputs = []
  83. get_comp_node = None
  84. need_cvt = False
  85. for i in old_inputs:
  86. if isinstance(i, RawTensor):
  87. get_comp_node = lambda cn=i.device.to_c(): cn
  88. else:
  89. need_cvt = True
  90. inputs.append(i)
  91. if not need_cvt:
  92. return inputs
  93. if get_comp_node is None:
  94. def get_comp_node():
  95. return config.comp_node
  96. for idx, var in enumerate(inputs):
  97. if not isinstance(var, RawTensor):
  98. var = as_raw_tensor(var)
  99. inputs[idx] = var
  100. return inputs
  101. def invoke_op(op, inputs_, cvt_inputs=canonize_inputs):
  102. inputs = cvt_inputs(
  103. inputs_, config=megengine.core._imperative_rt.OperatorNodeConfig()
  104. )
  105. return apply(op, *inputs)
  106. def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
  107. assert isinstance(inp, RawTensor)
  108. if not isinstance(tuple_val, tuple):
  109. tuple_val = (tuple_val,)
  110. def as_tensor(v):
  111. if not isinstance(v, RawTensor):
  112. vi = np.ascontiguousarray(v, dtype=np.int32)
  113. assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v)
  114. v = as_raw_tensor(vi)
  115. return v
  116. new_axes = []
  117. tensors = []
  118. items = []
  119. cur_axis = -1
  120. for i_idx, i in enumerate(tuple_val):
  121. cur_axis += 1
  122. if i is np.newaxis:
  123. if cur_axis >= 0:
  124. new_axes.append(cur_axis)
  125. continue
  126. if i is Ellipsis:
  127. cur_axis = -1
  128. for j in tuple_val[:i_idx:-1]:
  129. if j is Ellipsis:
  130. raise IndexError("only one ellipsis is allowed")
  131. if j is np.newaxis:
  132. new_axes.append(cur_axis)
  133. cur_axis -= 1
  134. continue
  135. item = [
  136. cur_axis,
  137. ]
  138. def push(v, item, tensors):
  139. if v is None:
  140. item.append(False)
  141. else:
  142. item.append(True)
  143. tensors.append(as_tensor(v))
  144. if isinstance(i, slice):
  145. if i.start is None and i.stop is None and i.step is None:
  146. continue
  147. push(i.start, item, tensors)
  148. push(i.stop, item, tensors)
  149. push(i.step, item, tensors)
  150. item.append(False) # idx
  151. else:
  152. item += [False,] * 3 # begin, end, stop
  153. push(i, item, tensors)
  154. assert len(item) == 5
  155. items.append(item)
  156. if new_axes:
  157. raise IndexError("newaxis is not allowed here")
  158. return inp, tensors, items
  159. def dimshuffle(*args, **kwargs):
  160. op = all_ops.Dimshuffle(**kwargs).to_c()
  161. return invoke_op(op, args)
  162. def broadcast(input, tshape):
  163. op = all_ops.Broadcast().to_c()
  164. return invoke_op(op, (input, tshape), canonize_reshape)
  165. def subtensor(input, tuple_val):
  166. input, tensors, items = unpack_getitem(input, tuple_val)
  167. op = all_ops.Subtensor(items).to_c()
  168. return invoke_op(op, (input, *tensors))
  169. def set_subtensor(input, value, tuple_val):
  170. input, tensors, items = unpack_getitem(input, tuple_val)
  171. op = all_ops.SetSubtensor(items).to_c()
  172. return invoke_op(op, (input, value, *tensors))
  173. def incr_subtensor(input, value, tuple_val):
  174. input, tensors, items = unpack_getitem(input, tuple_val)
  175. op = all_ops.IncrSubtensor(items).to_c()
  176. return invoke_op(op, (input, value, *tensors))
  177. def advance_indexing(input, tuple_val):
  178. input, tensors, items = unpack_getitem(input, tuple_val)
  179. op = all_ops.IndexingMultiAxisVec(items).to_c()
  180. return invoke_op(op, (input, *tensors))
  181. def set_advance_indexing(input, value, tuple_val):
  182. input, tensors, items = unpack_getitem(input, tuple_val)
  183. op = all_ops.IndexingSetMultiAxisVec(items).to_c()
  184. return invoke_op(op, (input, value, *tensors))
  185. def incr_advance_indexing(input, value, tuple_val):
  186. input, tensors, items = unpack_getitem(input, tuple_val)
  187. op = all_ops.IndexingIncrMultiAxisVec(items).to_c()
  188. return invoke_op(op, (input, value, *tensors))
  189. def mesh_indexing(input, tuple_val):
  190. input, tensors, items = unpack_getitem(input, tuple_val)
  191. op = all_ops.MeshIndexing(items).to_c()
  192. return invoke_op(op, (input, *tensors))
  193. def set_mesh_indexing(input, value, tuple_val):
  194. input, tensors, items = unpack_getitem(input, tuple_val)
  195. op = all_ops.SetMeshIndexing(items).to_c()
  196. return invoke_op(op, (input, value, *tensors))
  197. def incr_mesh_indexing(input, value, tuple_val):
  198. input, tensors, items = unpack_getitem(input, tuple_val)
  199. op = all_ops.IncrMeshIndexing(items).to_c()
  200. return invoke_op(op, (input, value, *tensors))
  201. def batched_mesh_indexing(input, tuple_val):
  202. input, tensors, items = unpack_getitem(input, tuple_val)
  203. op = all_ops.BatchedMeshIndexing(items).to_c()
  204. return invoke_op(op, (input, *tensors))
  205. def batched_set_mesh_indexing(input, value, tuple_val):
  206. input, tensors, items = unpack_getitem(input, tuple_val)
  207. op = all_ops.BatchedSetMeshIndexing(items).to_c()
  208. return invoke_op(op, (input, value, *tensors))
  209. def batched_incr_mesh_indexing(input, value, tuple_val):
  210. input, tensors, items = unpack_getitem(input, tuple_val)
  211. op = all_ops.BatchedIncrMeshIndexing(items).to_c()
  212. return invoke_op(op, (input, value, *tensors))
  213. def test_dimshuffle():
  214. x = np.arange(10).reshape(2, 5).astype("int32")
  215. xx = as_raw_tensor(x)
  216. (yy,) = dimshuffle(xx, pattern="1x0")
  217. np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy())
  218. def test_broadcast():
  219. x = np.arange(10).reshape(1, 10).astype("int32")
  220. xx = as_raw_tensor(x)
  221. (yy,) = broadcast(xx, (10, 10))
  222. np.testing.assert_equal(np.repeat(x, 10, 0), yy.numpy())
  223. def test_subtensor():
  224. x = np.arange(25).reshape(5, 5).astype("int32")
  225. d = np.arange(2).astype("int32")
  226. xx = as_raw_tensor(x)
  227. (yy0,) = subtensor(xx, (slice(0, 4, 2), 3))
  228. (yy1,) = set_subtensor(xx, d, (slice(0, 4, 2), 3))
  229. (yy2,) = incr_subtensor(xx, d, (slice(0, 4, 2), 3))
  230. np.testing.assert_equal(x[0:4:2, 3], yy0.numpy())
  231. x_ = x.copy()
  232. x_[0:4:2, 3] = d
  233. np.testing.assert_equal(x_, yy1.numpy())
  234. x_ = x.copy()
  235. x_[0:4:2, 3] += d
  236. np.testing.assert_equal(x_, yy2.numpy())
  237. def test_advance_indexing():
  238. x = np.arange(25).reshape(5, 5).astype("int32")
  239. d = np.arange(15).reshape(3, 5).astype("int32")
  240. xx = as_raw_tensor(x)
  241. (yy0,) = advance_indexing(xx, ((0, 4, 2), slice(None, None, None)))
  242. (yy1,) = set_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None)))
  243. (yy2,) = incr_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None)))
  244. np.testing.assert_equal(x[(0, 4, 2), :], yy0.numpy())
  245. x_ = x.copy()
  246. x_[(0, 4, 2), :] = d
  247. np.testing.assert_equal(x_, yy1.numpy())
  248. x_ = x.copy()
  249. x_[(0, 4, 2), :] += d
  250. np.testing.assert_equal(x_, yy2.numpy())
  251. def test_mesh_indexing():
  252. x = np.arange(25).reshape(5, 5).astype("int32")
  253. d = np.arange(6).reshape(3, 2).astype("int32")
  254. xx = as_raw_tensor(x)
  255. (yy0,) = mesh_indexing(xx, (slice(0, 5, 2), (1, 3)))
  256. (yy1,) = set_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3)))
  257. (yy2,) = incr_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3)))
  258. r = np.ndarray(shape=(3, 2), dtype="int32")
  259. for i0, i1 in enumerate(range(0, 5, 2)):
  260. for j0, j1 in enumerate((1, 3)):
  261. r[i0, j0] = x[i1, j1]
  262. np.testing.assert_equal(r, yy0.numpy())
  263. r = x.copy()
  264. for i0, i1 in enumerate(range(0, 5, 2)):
  265. for j0, j1 in enumerate((1, 3)):
  266. r[i1, j1] = d[i0, j0]
  267. np.testing.assert_equal(r, yy1.numpy())
  268. r = x.copy()
  269. for i0, i1 in enumerate(range(0, 5, 2)):
  270. for j0, j1 in enumerate((1, 3)):
  271. r[i1, j1] += d[i0, j0]
  272. np.testing.assert_equal(r, yy2.numpy())
  273. def test_batched_mesh_indexing():
  274. x = np.arange(24).reshape(2, 3, 4).astype("int32")
  275. d = np.arange(12).reshape(2, 2, 3).astype("int32")
  276. xx = as_raw_tensor(x)
  277. s = [(0, 1, 2), (1, 2, 3)]
  278. (yy0,) = batched_mesh_indexing(xx, (slice(None, None, None), [(0, 2)] * 2, s))
  279. (yy1,) = batched_set_mesh_indexing(
  280. xx, d, (slice(None, None, None), [(0, 2)] * 2, s)
  281. )
  282. (yy2,) = batched_incr_mesh_indexing(
  283. xx, d, (slice(None, None, None), [(0, 2)] * 2, s)
  284. )
  285. r = np.ndarray(shape=(2, 2, 3), dtype="int32")
  286. for i in range(2):
  287. for j0, j1 in enumerate((0, 2)):
  288. for k0, k1 in enumerate(s[i]):
  289. r[i, j0, k0] = x[i, j1, k1]
  290. np.testing.assert_equal(r, yy0.numpy())
  291. r = x.copy()
  292. for i in range(2):
  293. for j0, j1 in enumerate((0, 2)):
  294. for k0, k1 in enumerate(s[i]):
  295. r[i, j1, k1] = d[i, j0, k0]
  296. np.testing.assert_equal(r, yy1.numpy())
  297. r = x.copy()
  298. for i in range(2):
  299. for j0, j1 in enumerate((0, 2)):
  300. for k0, k1 in enumerate(s[i]):
  301. r[i, j1, k1] += d[i, j0, k0]
  302. np.testing.assert_equal(r, yy2.numpy())
  303. # high level
  304. def test_advance_indexing_high_level():
  305. x = np.arange(25).reshape(5, 5).astype("int32")
  306. d = np.arange(15).reshape(3, 5).astype("int32")
  307. xx = Tensor(x)
  308. np.testing.assert_equal(x[1, :], xx[1, :].numpy())
  309. np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
  310. np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy())
  311. np.testing.assert_equal(x[:, :], xx[:, :].numpy())
  312. np.testing.assert_equal(x[1, 1], xx[1, 1].numpy())
  313. yy = xx[(0, 4, 2), :]
  314. np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy())
  315. x_ = x.copy()
  316. x_[(0, 4, 2), :] = d
  317. xx_ = Tensor(xx)
  318. xx_[(0, 4, 2), :] = d
  319. np.testing.assert_equal(x_, xx_.numpy())
  320. x = np.arange(27).reshape(3, 3, 3).astype("int32")
  321. xx = Tensor(x)
  322. np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy())
  323. np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy())
  324. np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy())
  325. np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy())
  326. np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy())
  327. np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
  328. np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy())
  329. x_ = x.copy()
  330. x_[1, 1, 1] = -1
  331. xx[1, 1, 1] = -1
  332. np.testing.assert_equal(x_, xx.numpy())
  333. x_[:, 1, 1] = -2
  334. xx[:, 1, 1] = x_[:, 1, 1]
  335. np.testing.assert_equal(x_, xx.numpy())
  336. x_[0:1, :, 1] = -3
  337. xx[0:1, :, 1] = x_[0:1, :, 1]
  338. np.testing.assert_equal(x_, xx.numpy())
  339. x_[0:1, :, 1] = -4
  340. y = Tensor(x_)
  341. xx[0:1, :, 1] = y[0:1, :, 1]
  342. np.testing.assert_equal(y.numpy(), xx.numpy())
  343. x[:] = 1
  344. xx[:] = 1
  345. np.testing.assert_equal(x, xx.numpy())
  346. x = np.arange(9).reshape(3, 3).astype("int32")
  347. xx = Tensor(x)
  348. y = np.array([1, 2])
  349. yy = Tensor(y)
  350. np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
  351. # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) # FIXME
  352. np.testing.assert_equal(x[:, y], xx[:, y].numpy())
  353. np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
  354. x_ = x.copy()
  355. x_[:, y[0]] = -1
  356. xx_ = Tensor(x_)
  357. xx[:, yy[0]] = xx_[:, yy[0]]
  358. np.testing.assert_equal(x_, xx.numpy())
  359. x_[:, y] = -1
  360. xx_ = Tensor(x_)
  361. xx[:, yy] = xx_[:, yy]
  362. np.testing.assert_equal(x_, xx.numpy())
  363. x = np.arange(9).reshape(3, 3).astype("int32")
  364. xx = Tensor(x)
  365. y = np.array([1])
  366. yy = Tensor(y)
  367. np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
  368. # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) # FIXME
  369. np.testing.assert_equal(x[:, y], xx[:, y].numpy())
  370. # XXX: no way to tell whether yy is scalar or ndim=1 array
  371. np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
  372. x = np.arange(9).reshape(3, 3).astype("int32")
  373. xx = Tensor(x)
  374. np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy())
  375. np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy())
  376. def test_advance_indexing_with_bool():
  377. a = np.arange(9).reshape(3, 3).astype(np.float32)
  378. b = np.array([1, 2, 3])
  379. c = np.array([1, 2, 3])
  380. aa = Tensor(a)
  381. bb = Tensor(b)
  382. cc = Tensor(c)
  383. np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy())
  384. a[b == 1, c == 2] = -1.0
  385. aa[bb == 1, cc == 2] = -1.0
  386. np.testing.assert_equal(a, aa.numpy())
  387. a = np.arange(9).reshape(3, 3).astype(np.float32)
  388. b = np.array([False, True, True])
  389. c = np.array([2, 0]).astype(np.int32)
  390. aa = Tensor(a)
  391. bb = Tensor(b)
  392. cc = Tensor(c)
  393. np.testing.assert_equal(a[b, c], aa[bb, cc].numpy())
  394. a[b, c] = -1.0
  395. aa[bb, cc] = -1.0
  396. np.testing.assert_equal(a, aa.numpy())
  397. d = np.array([-1, -2], dtype=np.float32)
  398. dd = Tensor(d)
  399. a[b, c] = d
  400. aa[bb, cc] = dd
  401. np.testing.assert_equal(a, aa.numpy())
  402. a = np.ones((2, 2))
  403. b = np.array([[True, False], [False, True]])
  404. aa = Tensor(a)
  405. bb = Tensor(b)
  406. np.testing.assert_equal(a[b], aa[bb].numpy())
  407. b[:] = True
  408. bb[:] = True
  409. np.testing.assert_equal(a[b], aa[bb].numpy())
  410. np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy())
  411. a = np.ones((2, 2), dtype=np.int32)
  412. b = np.array([[False, False], [False, False]])
  413. aa = Tensor(a)
  414. bb = Tensor(b)
  415. np.testing.assert_equal(a[b], aa[bb].numpy())
  416. b = np.array([False, False])
  417. bb = Tensor(b)
  418. np.testing.assert_equal(a[b], aa[bb].numpy().reshape(a[b].shape)) # FIXME
  419. a = np.arange(576).reshape(2, 3, 4, 3, 4, 2).astype("int32")
  420. aa = Tensor(a)
  421. b = (np.random.sample((2, 3, 4)) > 0.5).astype("bool")
  422. bb = Tensor(b)
  423. np.testing.assert_equal(a[b, :, 0:4:2], aa[bb, :, 0:4:2].numpy())
  424. b = (np.random.sample((4, 3, 4)) > 0.5).astype("bool")
  425. bb = Tensor(b)
  426. np.testing.assert_equal(a[..., b, 0:2], aa[..., bb, 0:2].numpy())
  427. b = (np.random.sample((3, 4, 3)) > 0.5).astype("bool")
  428. bb = Tensor(b)
  429. np.testing.assert_equal(
  430. a[:, b, 0:2, [True, False]], aa[:, bb, 0:2, [True, False]].numpy()
  431. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。