You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_network_node.py 21 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758
  1. import io
  2. import os
  3. import platform
  4. import numpy as np
  5. import pytest
  6. import megengine.core.tensor.dtype as dtype
  7. import megengine.core.tensor.megbrain_graph as G
  8. import megengine.functional as F
  9. import megengine.module as M
  10. import megengine.random as rand
  11. from megengine.core._imperative_rt.core2 import apply
  12. from megengine.core._wrap import Device
  13. from megengine.core.ops import builtin
  14. from megengine.device import (
  15. get_cuda_compute_capability,
  16. get_device_count,
  17. is_cuda_available,
  18. )
  19. from megengine.functional.debug_param import (
  20. get_execution_strategy,
  21. set_execution_strategy,
  22. )
  23. from megengine.functional.external import tensorrt_runtime_opr
  24. from megengine.jit.tracing import trace
  25. from megengine.tensor import Tensor
  26. from megengine.utils.comp_graph_tools import GraphInference
  27. from megengine.utils.network import Network as Net
  28. def check_pygraph_dump(trace_func, inp_data, expect_results, max_err=None):
  29. orig_model = io.BytesIO()
  30. inp_size = len(inp_data)
  31. out_size = len(expect_results)
  32. arg_names = ["arg_{}".format(i) for i in range(inp_size)]
  33. output_names = ["out_{}".format(i) for i in range(out_size)]
  34. trace_func.dump(
  35. orig_model,
  36. arg_names=arg_names,
  37. output_names=output_names,
  38. optimize_for_inference=False,
  39. )
  40. orig_model.seek(0)
  41. net = Net.load(orig_model)
  42. file = io.BytesIO()
  43. net.dump(file, optimize_for_inference=False)
  44. file.seek(0)
  45. graph = GraphInference(file)
  46. inp_dict = dict([(arg_names[i], inp_data[i].numpy()) for i in range(inp_size)])
  47. results = graph.run(inp_dict=inp_dict)
  48. for ind, tensor in enumerate(expect_results):
  49. if max_err:
  50. np.testing.assert_almost_equal(
  51. tensor.numpy(), results[output_names[ind]], max_err
  52. )
  53. else:
  54. np.testing.assert_equal(tensor.numpy(), results[output_names[ind]])
  55. assert tensor.dtype == results[output_names[ind]].dtype
  56. def test_elemwise():
  57. @trace(symbolic=True, capture_as_const=True)
  58. def fwd(x, y):
  59. z1 = x * y
  60. z2 = x + y
  61. z3 = z1 / z2
  62. z3 = z3 ** 3
  63. return z3
  64. x = Tensor([1.0, 2.0])
  65. y = Tensor([3.0, 5.0])
  66. result = fwd(x, y)
  67. check_pygraph_dump(fwd, [x, y], [result])
  68. def test_reduce():
  69. @trace(symbolic=True, capture_as_const=True)
  70. def fwd(data):
  71. x = data.sum(axis=2)
  72. x = x.mean(axis=1)
  73. return x
  74. data = Tensor(np.random.random((1, 32, 32)))
  75. result = fwd(data)
  76. check_pygraph_dump(fwd, [data], [result])
  77. def test_typecvt():
  78. @trace(symbolic=True, capture_as_const=True)
  79. def fwd(data):
  80. return data.astype(dtype.qint8(0.8))
  81. x = Tensor(np.random.random((2, 3)) * 255)
  82. result = fwd(x)
  83. check_pygraph_dump(fwd, [x], [result])
  84. def test_matinv():
  85. @trace(symbolic=True, capture_as_const=True)
  86. def fwd(data):
  87. return F.matinv(data)
  88. data = Tensor(np.random.random((5, 5)))
  89. result = fwd(data)
  90. check_pygraph_dump(fwd, [data], [result])
  91. @pytest.mark.parametrize(
  92. "execution_strategy", ["HEURISTIC_REPRODUCIBLE", "PROFILE_REPRODUCIBLE"]
  93. )
  94. def test_matmul(execution_strategy):
  95. @trace(symbolic=True, capture_as_const=True)
  96. def fwd(data1, data2):
  97. return F.matmul(data1, data2)
  98. old = get_execution_strategy()
  99. set_execution_strategy(execution_strategy)
  100. max_err = None
  101. if execution_strategy == "PROFILE_REPRODUCIBLE":
  102. max_err = 1e-5
  103. data1 = Tensor(np.random.random((32, 64)))
  104. data2 = Tensor(np.random.random((64, 16)))
  105. result = fwd(data1, data2)
  106. check_pygraph_dump(fwd, [data1, data2], [result], max_err=max_err)
  107. set_execution_strategy(old)
  108. def test_batchmatmul():
  109. @trace(symbolic=True, capture_as_const=True)
  110. def fwd(x, y):
  111. return F.matmul(x, y)
  112. x = Tensor(np.random.random((3, 3, 5)))
  113. y = Tensor(np.random.random((3, 5, 3)))
  114. result = fwd(x, y)
  115. check_pygraph_dump(fwd, [x, y], [result])
  116. def test_dot():
  117. @trace(symbolic=True, capture_as_const=True)
  118. def fwd(x, y):
  119. return F.dot(x, y)
  120. x = Tensor([1.0, 2.0, 3.0])
  121. y = Tensor([3.0, 4.0, 5.0])
  122. result = fwd(x, y)
  123. check_pygraph_dump(fwd, [x, y], [result])
  124. def test_svd():
  125. @trace(symbolic=True, capture_as_const=True)
  126. def fwd(data):
  127. _, out, _ = F.svd(data)
  128. return out
  129. input = Tensor(np.random.random((1, 1, 3, 3)))
  130. result = fwd(input)
  131. check_pygraph_dump(fwd, [input], [result])
  132. def test_conv():
  133. conv = M.Conv2d(3, 32, 3)
  134. @trace(symbolic=True, capture_as_const=True)
  135. def fwd(data):
  136. return conv(data)
  137. data = Tensor(np.random.random((1, 3, 32, 32)))
  138. result = fwd(data)
  139. check_pygraph_dump(fwd, [data], [result])
  140. def test_deformable_conv():
  141. if not is_cuda_available():
  142. return
  143. conv = M.DeformableConv2d(3, 32, 3)
  144. @trace(symbolic=True, capture_as_const=True)
  145. def fwd(data, offset, mask):
  146. return conv(data, offset, mask)
  147. data = Tensor(np.random.random((1, 3, 32, 32)))
  148. offset = Tensor(np.ones((32, 3 * 3 * 2, 30, 30)).astype("int32") * 5)
  149. mask = Tensor(np.ones((32, 3 * 3, 30, 30)).astype("int32"))
  150. out = fwd(data, offset, mask)
  151. check_pygraph_dump(fwd, [data, offset, mask], [out])
  152. def test_convtranspose():
  153. deconv = M.ConvTranspose2d(32, 32, 3)
  154. @trace(symbolic=True, capture_as_const=True)
  155. def fwd(data):
  156. return deconv(data)
  157. data = Tensor(np.random.random((1, 32, 32, 32)))
  158. result = fwd(data)
  159. # cu111 has 1e-7 diff
  160. check_pygraph_dump(fwd, [data], [result], 5)
  161. @pytest.mark.skip(reason="pytest aborted")
  162. def test_grouplocal():
  163. n = M.LocalConv2d(3, 32, 32, 32, 3)
  164. @trace(symbolic=True, capture_as_const=True)
  165. def fwd(data):
  166. return n(data)
  167. input = Tensor(np.random.random((1, 3, 32, 32)))
  168. result = fwd(input)
  169. check_pygraph_dump(fwd, [input], [result])
  170. def test_pooling():
  171. @trace(symbolic=True, capture_as_const=True)
  172. def fwd(data):
  173. out = F.max_pool2d(data, 2, 2)
  174. out = F.avg_pool2d(out, 2, 2)
  175. return out
  176. data = Tensor(np.random.random((1, 3, 64, 64)))
  177. result = fwd(data)
  178. check_pygraph_dump(fwd, [data], [result])
  179. def test_adaptivepooling():
  180. pool1 = M.AdaptiveMaxPool2d((2, 2))
  181. pool2 = M.AdaptiveAvgPool2d((2, 2))
  182. @trace(symbolic=True, capture_as_const=True)
  183. def fwd(data):
  184. out = pool1(data)
  185. out = pool2(out)
  186. return out
  187. input = Tensor(np.random.random((1, 3, 32, 32)))
  188. result = fwd(input)
  189. check_pygraph_dump(fwd, [input], [result])
  190. def test_roipooling():
  191. inp = Tensor(np.random.random((1, 1, 128, 128)))
  192. rois = Tensor(np.random.random((4, 5)))
  193. @trace(symbolic=True, capture_as_const=True)
  194. def fwd(inp, rois):
  195. return F.vision.roi_pooling(inp, rois, (2, 2), scale=2.0)
  196. output = fwd(inp, rois)
  197. check_pygraph_dump(fwd, [inp, rois], [output])
  198. def test_deformable_ps_roi_pooling():
  199. inp = Tensor(np.random.random((1, 256, 64, 64)).astype("float32"))
  200. rois = Tensor(np.random.random((1, 5)).astype("float32"))
  201. trans = Tensor(np.random.random((24, 2, 7, 7)).astype("float32"))
  202. pooled_h = 7
  203. pooled_w = 7
  204. sample_per_part = 4
  205. no_trans = False
  206. part_size = 7
  207. spatial_scale = 1.0 / 64
  208. trans_std = 0.1
  209. @trace(symbolic=True, capture_as_const=True)
  210. def fwd(inp, rois, trans):
  211. y = F.deformable_psroi_pooling(
  212. inp,
  213. rois,
  214. trans,
  215. no_trans,
  216. part_size,
  217. pooled_h,
  218. pooled_w,
  219. sample_per_part,
  220. spatial_scale,
  221. trans_std,
  222. )
  223. return y
  224. result = fwd(inp, rois, trans)
  225. check_pygraph_dump(fwd, [inp, rois, trans], [result])
  226. @pytest.mark.require_ngpu(1)
  227. @pytest.mark.skipif(
  228. get_cuda_compute_capability(0) < 61,
  229. reason="does not support int8 when gpu compute capability less than 6.1",
  230. )
  231. def test_convbias():
  232. @trace(symbolic=True, capture_as_const=True)
  233. def fwd(inp, weight, bias):
  234. return F.quantized.conv_bias_activation(
  235. inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu"
  236. )
  237. inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
  238. weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
  239. bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
  240. result = fwd(inp, weight, bias)
  241. check_pygraph_dump(fwd, [inp, weight, bias], [result])
  242. @pytest.mark.skip(reason="does not support int4 when cuda version is lower than 10.2")
  243. def test_conv_bias_int4():
  244. @trace(symbolic=True, capture_as_const=True)
  245. def fwd(inp, weight, bias):
  246. return F.quantized.conv_bias_activation(
  247. inp,
  248. weight,
  249. bias,
  250. dtype=dtype.quint4(scale=1.0, zero_point=0),
  251. nonlinear_mode="relu",
  252. )
  253. inp = Tensor(
  254. np.random.random((1, 3, 64, 64)), dtype=dtype.quint4(scale=1.0, zero_point=0)
  255. )
  256. weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint4(scale=1.0))
  257. bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
  258. result = fwd(inp, weight, bias)
  259. check_pygraph_dump(fwd, [inp, weight, bias], [result])
  260. def test_batch_convbias():
  261. if is_cuda_available():
  262. return
  263. @trace(symbolic=True, capture_as_const=True)
  264. def fwd(inp, weight, bias):
  265. return F.quantized.batch_conv_bias_activation(
  266. inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu"
  267. )
  268. inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
  269. weight = Tensor(np.random.random((1, 32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
  270. bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
  271. result = fwd(inp, weight, bias)
  272. check_pygraph_dump(fwd, [inp, weight, bias], [result])
  273. def test_batchnorm():
  274. bn = M.BatchNorm2d(32)
  275. bn.eval()
  276. @trace(symbolic=True, capture_as_const=True)
  277. def fwd(data):
  278. return bn(data)
  279. data = Tensor(np.random.random((1, 32, 32, 32)))
  280. result = fwd(data)
  281. check_pygraph_dump(fwd, [data], [result])
  282. def test_roialign():
  283. inp = Tensor(np.random.randn(1, 1, 128, 128))
  284. rois = Tensor(np.random.random((4, 5)))
  285. @trace(symbolic=True, capture_as_const=True)
  286. def fwd(inp, rois):
  287. return F.vision.roi_align(inp, rois, (2, 2))
  288. output = fwd(inp, rois)
  289. check_pygraph_dump(fwd, [inp, rois], [output])
  290. def test_warpperspective():
  291. inp_shape = (1, 1, 4, 4)
  292. x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  293. M_shape = (1, 3, 3)
  294. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  295. M = Tensor(
  296. np.array(
  297. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  298. ).reshape(M_shape)
  299. )
  300. @trace(symbolic=True, capture_as_const=True)
  301. def fwd(x, M):
  302. return F.vision.warp_perspective(x, M, (2, 2))
  303. result = fwd(x, M)
  304. check_pygraph_dump(fwd, [x, M], [result])
  305. def test_warpaffine():
  306. inp_shape = (1, 3, 3, 3)
  307. x = Tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
  308. weightv = Tensor([[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]])
  309. @trace(symbolic=True, capture_as_const=True)
  310. def fwd(x, weightv):
  311. return F.vision.warp_affine(x, weightv, (2, 2), border_mode="wrap")
  312. outp = fwd(x, weightv)
  313. check_pygraph_dump(fwd, [x, weightv], [outp])
  314. def test_remap():
  315. inp_shape = (1, 1, 4, 4)
  316. inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  317. map_xy_shape = (1, 2, 2, 2)
  318. map_xy = Tensor(
  319. np.array(
  320. [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
  321. ).reshape(map_xy_shape)
  322. )
  323. @trace(symbolic=True, capture_as_const=True)
  324. def fwd(inp, map_xy):
  325. return F.vision.remap(inp, map_xy)
  326. out = fwd(inp, map_xy)
  327. check_pygraph_dump(fwd, [inp, map_xy], [out])
  328. def test_resize():
  329. x = Tensor(np.random.randn(10, 3, 32, 32))
  330. @trace(symbolic=True, capture_as_const=True)
  331. def fwd(x):
  332. return F.vision.interpolate(x, size=(16, 16), mode="bilinear")
  333. out = fwd(x)
  334. check_pygraph_dump(fwd, [x], [out])
  335. def test_index_onehot():
  336. src = Tensor([[1.0, 2.0]])
  337. index = Tensor([0])
  338. @trace(symbolic=True, capture_as_const=True)
  339. def fwd(src, index):
  340. return F.indexing_one_hot(src, index)
  341. out = fwd(src, index)
  342. check_pygraph_dump(fwd, [src, index], [out])
  343. def test_set_onehot():
  344. x = Tensor(np.arange(1, 4, dtype=np.int32))
  345. @trace(symbolic=True, capture_as_const=True)
  346. def fwd(x):
  347. return F.one_hot(x, num_classes=4)
  348. out = fwd(x)
  349. check_pygraph_dump(fwd, [x], [out])
  350. def test_copy():
  351. x = Tensor([1, 2, 3])
  352. @trace(symbolic=True, capture_as_const=True)
  353. def fwd(x):
  354. return x.to("cpu0:0")
  355. o = fwd(x)
  356. check_pygraph_dump(fwd, [x], [o])
  357. def test_argsort():
  358. @trace(symbolic=True, capture_as_const=True)
  359. def fwd(data):
  360. return F.argsort(data, True)
  361. data = Tensor([1.0, 2.0, 3.0, 5.0])
  362. result = fwd(data)
  363. check_pygraph_dump(fwd, [data], [result])
  364. def test_argmax_min():
  365. @trace(symbolic=True, capture_as_const=True)
  366. def fwd(data):
  367. return F.argmax(data), F.argmin(data)
  368. data = Tensor(np.random.random((10, 10)))
  369. result = fwd(data)
  370. check_pygraph_dump(fwd, [data], result)
  371. def test_condtake():
  372. mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
  373. x = Tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32))
  374. @trace(symbolic=True, capture_as_const=True)
  375. def fwd(mask, x):
  376. v, index = F.cond_take(mask, x)
  377. return v, index
  378. v, index = fwd(mask, x)
  379. check_pygraph_dump(fwd, [mask, x], [v, index])
  380. def test_topk():
  381. x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  382. @trace(symbolic=True, capture_as_const=True)
  383. def fwd(x):
  384. top, indices = F.topk(x, 5)
  385. return top, indices
  386. top, indices = fwd(x)
  387. check_pygraph_dump(fwd, [x], [top, indices])
  388. def test_random():
  389. @trace(symbolic=True, capture_as_const=True)
  390. def fwd():
  391. x = rand.uniform(size=(2, 2))
  392. y = rand.normal(size=(1, 3, 3, 3))
  393. return x, y
  394. x, y = fwd()
  395. check_pygraph_dump(fwd, [], [x, y])
  396. def test_tensor_gen():
  397. @trace(symbolic=True, capture_as_const=True)
  398. def fwd():
  399. a = F.linspace(3, 10, 3, device=Device("xpux").to_c())
  400. b = F.eye(3, device=Device("xpux").to_c())
  401. return a, b
  402. a, b = fwd()
  403. check_pygraph_dump(fwd, [], [a, b])
  404. def test_getvarshape():
  405. op = builtin.GetVarShape(axis=1)
  406. @trace(symbolic=True, capture_as_const=True)
  407. def fwd(data):
  408. return apply(op, data)[0]
  409. data = Tensor(np.random.random((1, 2, 3, 4)))
  410. result = fwd(data)
  411. check_pygraph_dump(fwd, [data], [result])
  412. def test_concat():
  413. @trace(symbolic=True, capture_as_const=True)
  414. def fwd(data1, data2):
  415. return F.concat([data1, data2], axis=1)
  416. x = Tensor(np.random.random((2, 3)))
  417. y = Tensor(np.random.random((2, 5)))
  418. result = fwd(x, y)
  419. check_pygraph_dump(fwd, [x, y], [result])
  420. def test_broadcast():
  421. inp = Tensor([[1], [2], [3], [4]])
  422. @trace(symbolic=True, capture_as_const=True)
  423. def fwd(inp):
  424. return F.broadcast_to(inp, (4, 4))
  425. out = fwd(inp)
  426. check_pygraph_dump(fwd, [inp], [out])
  427. def test_identity():
  428. @trace(symbolic=True, capture_as_const=True)
  429. def fwd(data):
  430. return F.copy(data)
  431. data = Tensor([1.0, 2.0])
  432. result = fwd(data)
  433. check_pygraph_dump(fwd, [data], [result])
  434. @pytest.mark.skip(reason="advance indexing trace error")
  435. def test_nms():
  436. x = np.zeros((100, 4))
  437. np.random.seed(42)
  438. x[:, :2] = np.random.rand(100, 2) * 20
  439. x[:, 2:] = np.random.rand(100, 2) * 20 + 100
  440. scores = Tensor(np.random.rand(100))
  441. inp = Tensor(x)
  442. @trace(symbolic=True, capture_as_const=True)
  443. def fwd(inp, scores):
  444. return F.nn.nms(inp, scores, iou_thresh=0.7, max_output=3)
  445. result = fwd(inp, scores)
  446. check_pygraph_dump(fwd, [inp, scores], [result])
  447. def test_dimshuffle():
  448. inp = Tensor([1, 2, 3, 4])
  449. @trace(symbolic=True, capture_as_const=True)
  450. def fwd(inp):
  451. return inp.T
  452. out = fwd(inp)
  453. check_pygraph_dump(fwd, [inp], [out])
  454. def test_reshape():
  455. @trace(symbolic=True, capture_as_const=True)
  456. def fwd(data):
  457. return data.reshape((1, 8))
  458. data = Tensor(np.random.random((1, 2, 2, 2)))
  459. result = fwd(data)
  460. check_pygraph_dump(fwd, [data], [result])
  461. def test_add_remove_axis():
  462. @trace(symbolic=True, capture_as_const=True)
  463. def fwd(data):
  464. x = F.expand_dims(data, [0, 0])
  465. y = F.squeeze(x, 0)
  466. return y
  467. data = Tensor([1.0, 2.0])
  468. result = fwd(data)
  469. check_pygraph_dump(fwd, [data], [result])
  470. @pytest.mark.parametrize("mode", ["get", "set", "inc"])
  471. def test_subtensor(mode):
  472. items = [[0, True, True, True, False], [1, False, False, False, True]]
  473. data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random(2))]
  474. if mode == "get":
  475. op = builtin.Subtensor(items)
  476. data = data[:1]
  477. if mode == "set":
  478. op = builtin.SetSubtensor(items)
  479. if mode == "inc":
  480. op = builtin.IncrSubtensor(items)
  481. tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)]
  482. @trace(symbolic=True, capture_as_const=True)
  483. def fwd(*tensors):
  484. return apply(op, *tensors)[0]
  485. result = fwd(*data, *tensors)
  486. check_pygraph_dump(fwd, data + tensors, [result])
  487. @pytest.mark.parametrize("mode", ["get", "set", "inc"])
  488. def test_advance_indexing(mode):
  489. items = [[0, False, False, False, True]]
  490. tensors = [Tensor([0, 4, 2])]
  491. data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 5)))]
  492. if mode == "get":
  493. op = builtin.IndexingMultiAxisVec(items)
  494. data = data[:1]
  495. if mode == "set":
  496. op = builtin.IndexingSetMultiAxisVec(items)
  497. if mode == "inc":
  498. op = builtin.IndexingIncrMultiAxisVec(items)
  499. @trace(symbolic=True, capture_as_const=True)
  500. def fwd(*tensors):
  501. return apply(op, *tensors)[0]
  502. result = fwd(*data, *tensors)
  503. check_pygraph_dump(fwd, data + tensors, [result])
  504. @pytest.mark.parametrize("mode", ["get", "set", "inc"])
  505. def test_mesh_indexing(mode):
  506. items = [[0, True, True, True, False], [1, False, False, False, True]]
  507. tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])]
  508. data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))]
  509. if mode == "get":
  510. op = builtin.IndexingMultiAxisVec(items)
  511. data = data[:1]
  512. if mode == "set":
  513. op = builtin.IndexingSetMultiAxisVec(items)
  514. if mode == "inc":
  515. op = builtin.IndexingIncrMultiAxisVec(items)
  516. @trace(symbolic=True, capture_as_const=True)
  517. def fwd(*tensors):
  518. return apply(op, *tensors)[0]
  519. result = fwd(*data, *tensors)
  520. check_pygraph_dump(fwd, data + tensors, [result])
  521. @pytest.mark.parametrize("mode", ["get", "set", "inc"])
  522. def test_batch_mesh_indexing(mode):
  523. items = [[1, False, False, False, True], [2, False, False, False, True]]
  524. tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])]
  525. data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))]
  526. if mode == "get":
  527. op = builtin.BatchedMeshIndexing(items)
  528. data = data[:1]
  529. if mode == "set":
  530. op = builtin.BatchedSetMeshIndexing(items)
  531. if mode == "inc":
  532. op = builtin.BatchedIncrMeshIndexing(items)
  533. @trace(symbolic=True, capture_as_const=True)
  534. def fwd(*tensors):
  535. return apply(op, *tensors)[0]
  536. result = fwd(*data, *tensors)
  537. check_pygraph_dump(fwd, data + tensors, [result])
  538. @pytest.mark.skip(reason="tmp skip")
  539. def test_assert_equal():
  540. g = G.Graph()
  541. inp1 = g.make_h2d(dtype=np.float32, device="xpux")
  542. inp2 = g.make_h2d(dtype=np.float32, device="xpux")
  543. op = builtin.AssertEqual(maxerr=1e-5)
  544. out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0]
  545. g.compile(out)
  546. file = io.BytesIO()
  547. out_model = G.dump_graph([out])
  548. file.write(out_model[0])
  549. file.seek(0)
  550. net = Net.load(file)
  551. dump_file = io.BytesIO()
  552. net.dump(dump_file)
  553. dump_file.seek(0)
  554. g = GraphInference(dump_file)
  555. g.run(np.array([1.0, 2.0]), np.array([1.0, 2.0]))
  556. def test_elemwise_multitype():
  557. op = builtin.ElemwiseMultiType(mode="qadd", dtype=dtype.qint32(2.0))
  558. @trace(symbolic=True, capture_as_const=True)
  559. def fwd(x, y):
  560. return apply(op, x, y)[0]
  561. x = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
  562. y = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
  563. result = fwd(x, y)
  564. check_pygraph_dump(fwd, [x, y], [result])
  565. def test_cvtcolor():
  566. inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
  567. x = Tensor(inp)
  568. @trace(symbolic=True, capture_as_const=True)
  569. def fwd(inp):
  570. return F.vision.cvt_color(inp, mode="RGB2GRAY")
  571. result = fwd(x)
  572. check_pygraph_dump(fwd, [x], [result])