You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor.py 25 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896
  1. # -*- coding: utf-8 -*-
  2. import os
  3. import platform
  4. import numpy as np
  5. import pytest
  6. from utils import get_var_value, make_tensor, opr_test
  7. import megengine.functional as F
  8. from megengine import tensor
  9. from megengine.core._trace_option import use_symbolic_shape
  10. from megengine.core.tensor import megbrain_graph as G
  11. from megengine.core.tensor.utils import astensor1d
  12. from megengine.jit import trace
  13. from megengine.utils.network import Network, set_symbolic_shape
  14. from megengine.utils.network_node import VarNode
  15. def test_eye():
  16. dtypes = [np.float32, np.bool]
  17. cases = [{"input": [10, 20]}, {"input": [30]}]
  18. for dtype in dtypes:
  19. for case in cases:
  20. np.testing.assert_allclose(
  21. F.eye(case["input"], dtype=dtype).numpy(),
  22. np.eye(*case["input"]).astype(dtype),
  23. )
  24. np.testing.assert_allclose(
  25. F.eye(*case["input"], dtype=dtype).numpy(),
  26. np.eye(*case["input"]).astype(dtype),
  27. )
  28. np.testing.assert_allclose(
  29. F.eye(tensor(case["input"]), dtype=dtype).numpy(),
  30. np.eye(*case["input"]).astype(dtype),
  31. )
  32. @pytest.mark.parametrize("is_varnode", [False, True])
  33. def test_diag(is_varnode):
  34. if is_varnode:
  35. network = Network()
  36. else:
  37. network = None
  38. shapes = [(10, 10), (6, 9), (8, 7), (8,)]
  39. cases = []
  40. for shp in shapes:
  41. cases.append({"input": [np.random.random(shp).astype("float32")]})
  42. for axis in range(-2, 3):
  43. def run(data):
  44. return F.diag(data, k=axis)
  45. opr_test(cases, run, ref_fn=lambda x: np.diag(x, axis), network=network)
  46. def test_full():
  47. shape = (2, 3)
  48. values = [True, 4, 5.0]
  49. for value in values:
  50. np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value))
  51. assert F.full(shape, value).dtype == tensor(value).dtype
  52. @pytest.mark.parametrize("is_varnode", [True, False])
  53. def test_concat(is_varnode):
  54. if is_varnode:
  55. network = Network()
  56. else:
  57. network = None
  58. def get_data_shape(length: int):
  59. return (length, 2, 3)
  60. data1 = np.random.random(get_data_shape(5)).astype("float32")
  61. data2 = np.random.random(get_data_shape(6)).astype("float32")
  62. data3 = np.random.random(get_data_shape(7)).astype("float32")
  63. def run(data1, data2):
  64. return F.concat([data1, data2])
  65. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  66. opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
  67. @pytest.mark.parametrize("is_varnode", [True, False])
  68. def test_condtake(is_varnode):
  69. if is_varnode:
  70. network = Network()
  71. else:
  72. network = None
  73. x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  74. y = np.array([[True, False, True], [False, True, True]])
  75. xx = make_tensor(x, network)
  76. yy = make_tensor(y, network)
  77. val, idx = F.cond_take(yy, xx)
  78. if is_varnode:
  79. np.testing.assert_equal(get_var_value(val), x[y])
  80. np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
  81. else:
  82. np.testing.assert_equal(val.numpy(), x[y])
  83. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  84. @pytest.mark.parametrize("is_varnode", [True, False])
  85. def test_concat_device(is_varnode):
  86. if is_varnode:
  87. network = Network()
  88. else:
  89. network = None
  90. data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
  91. data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
  92. out = F.concat([data1, data2], device="cpu0")
  93. assert str(out.device).split(":")[0] == "cpu0"
  94. @pytest.mark.parametrize("is_varnode", [True, False])
  95. def test_stack(is_varnode):
  96. if is_varnode:
  97. network = Network()
  98. else:
  99. network = None
  100. data1 = np.random.random((3, 2, 2)).astype("float32")
  101. data2 = np.random.random((3, 2, 2)).astype("float32")
  102. data3 = np.random.random((3, 2, 2)).astype("float32")
  103. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  104. for ai in range(3):
  105. def run(data1, data2):
  106. return F.stack([data1, data2], axis=ai)
  107. opr_test(
  108. cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
  109. )
  110. @pytest.mark.parametrize("is_varnode", [True, False])
  111. def test_split_basic(is_varnode):
  112. if is_varnode:
  113. network = Network()
  114. saved_symbolic_shape = set_symbolic_shape(False)
  115. else:
  116. network = None
  117. data = np.random.random((2, 3, 4, 5)).astype(np.float32)
  118. inp = make_tensor(data, network)
  119. mge_out0 = F.split(inp, 2, axis=3)
  120. mge_out1 = F.split(inp, [3], axis=3)
  121. np_out = np.split(data, [3, 5], axis=3)
  122. assert len(mge_out0) == 2
  123. assert len(mge_out1) == 2
  124. np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
  125. np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])
  126. np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
  127. np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])
  128. try:
  129. F.split(inp, 4)
  130. assert False
  131. except ValueError as e:
  132. pass
  133. try:
  134. F.split(inp, [3, 2, 5], axis=3)
  135. assert False
  136. except ValueError as e:
  137. assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
  138. if is_varnode:
  139. set_symbolic_shape(saved_symbolic_shape)
  140. @pytest.mark.parametrize("symbolic", [None, False, True])
  141. def test_split(symbolic):
  142. inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
  143. inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
  144. def ref(inp, nsplits_or_sections, axis):
  145. return np.split(inp, nsplits_or_sections, axis)
  146. def func(inp, nsplits_or_sections, axis):
  147. return F.split(inp, nsplits_or_sections, axis)
  148. cases = [
  149. (inp1, 2, 3),
  150. (inp1, [3], 3),
  151. (inp1, [3, 3, 5], 3),
  152. (inp2, 2, 3),
  153. (inp2, [3], 3),
  154. (inp2, [3, 3, 5], 3),
  155. ]
  156. for case in cases:
  157. if symbolic is None:
  158. fn = func
  159. else:
  160. fn = trace(symbolic=symbolic)(func)
  161. for i in range(3 if symbolic is not None else 1):
  162. ref_out = ref(*case)
  163. out = fn(tensor(case[0]), case[1], case[2])
  164. assert len(ref_out) == len(out)
  165. for idx in range(len(ref_out)):
  166. np.testing.assert_equal(ref_out[idx], out[idx].numpy())
  167. @pytest.mark.parametrize("is_varnode", [True, False])
  168. def test_reshape(is_varnode):
  169. if is_varnode:
  170. network = Network()
  171. else:
  172. network = None
  173. x = np.arange(6, dtype="float32")
  174. xx = make_tensor(x, network)
  175. y = x.reshape(1, 2, 3)
  176. for shape in [
  177. (1, 2, 3),
  178. (1, -1, 3),
  179. (1, make_tensor(-1, network), 3),
  180. np.array([1, -1, 3], dtype="int32"),
  181. make_tensor([1, -1, 3], network),
  182. ]:
  183. yy = F.reshape(xx, shape)
  184. np.testing.assert_equal(yy.numpy(), y)
  185. @pytest.mark.parametrize("is_varnode", [True, False])
  186. def test_broadcast_auto_infer(is_varnode):
  187. if is_varnode:
  188. network = Network()
  189. else:
  190. network = None
  191. x = np.random.random((1, 2, 3)).astype(np.float32)
  192. xx = make_tensor(x, network)
  193. for shape in [
  194. (1, 2, 3),
  195. (1, None, 3),
  196. ]:
  197. yy = F.broadcast_to(xx, shape)
  198. np.testing.assert_equal(yy.numpy(), x)
  199. with pytest.raises(ValueError):
  200. F.broadcast_to(xx, (1, -1, 3))
  201. with pytest.raises(ValueError):
  202. F.broadcast_to(xx, (None, 1, 2, 3))
  203. F.broadcast_to(xx, (1, None, 2, 3))
  204. t = make_tensor(2, network)
  205. F.broadcast_to(xx, (t, None, 2, 3))
  206. @pytest.mark.parametrize("is_trace", [True, False])
  207. def test_reshape_on_empty_tensor(is_trace):
  208. input1_shape = (100, 0, 1)
  209. output1_shape = (100, 0, 10)
  210. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  211. input2_shape = (10, 0)
  212. output2_shape = (0,)
  213. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  214. input3_shape = (10, 0, 10)
  215. output3_shape = (0, 1, 2, 3)
  216. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  217. def comp(out, target_shp):
  218. assert out._tuple_shape == target_shp
  219. def func(x, shp):
  220. return F.reshape(x, shp)
  221. cases = [
  222. [data1, output1_shape],
  223. [data2, output2_shape],
  224. [data3, output3_shape],
  225. ]
  226. def test(func, inp, comp, target_shp):
  227. out = func(inp, target_shp)
  228. comp(out, target_shp)
  229. if is_trace:
  230. for symbolic in [False, True]:
  231. for inp, target_shp in cases:
  232. func_traced = trace(symbolic=symbolic)(func)
  233. test(func_traced, inp, comp, target_shp)
  234. test(func_traced, inp, comp, target_shp)
  235. test(func_traced, inp, comp, target_shp)
  236. else:
  237. for inp, target_shp in cases:
  238. test(func, inp, comp, target_shp)
  239. @pytest.mark.parametrize("is_varnode", [True, False])
  240. def test_reshape_shape_inference(is_varnode):
  241. if is_varnode:
  242. network = Network()
  243. saved_symbolic_shape = set_symbolic_shape(False)
  244. else:
  245. network = None
  246. x_shape_known = make_tensor([1, 2, 3, 4], network)
  247. x_shape_unknown = F.broadcast_to(
  248. make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
  249. )
  250. tshp_unknown = astensor1d(
  251. (make_tensor([2], network), make_tensor([2], network)), x_shape_known
  252. )
  253. tshp_known = astensor1d((2, 2), x_shape_known)
  254. tshp_known_unspec = astensor1d((2, -1), x_shape_known)
  255. def check_shape(output, target):
  256. source = output.shape
  257. if isinstance(source, tensor):
  258. source = source.numpy()
  259. np.testing.assert_equal(source, target.shape)
  260. def func(x, target_shape):
  261. return x.reshape(target_shape)
  262. cases = [
  263. {"input": [x_shape_known, tshp_unknown], "output": [np.zeros((2, 2)),]},
  264. {"input": [x_shape_unknown, tshp_unknown], "output": [np.zeros((2, 2)),]},
  265. {"input": [x_shape_known, tshp_known], "output": [np.zeros((2, 2)),]},
  266. {"input": [x_shape_known, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
  267. {"input": [x_shape_unknown, tshp_known], "output": [np.zeros((2, 2)),]},
  268. {"input": [x_shape_unknown, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
  269. ]
  270. opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
  271. if is_varnode:
  272. set_symbolic_shape(saved_symbolic_shape)
  273. @pytest.mark.parametrize("is_varnode", [True, False])
  274. def test_squeeze(is_varnode):
  275. if is_varnode:
  276. network = Network()
  277. saved_symbolic_shape = set_symbolic_shape(False)
  278. else:
  279. network = None
  280. x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
  281. xx = make_tensor(x, network)
  282. for axis in [None, 3, -4, (3, -4)]:
  283. y = np.squeeze(x, axis)
  284. yy = F.squeeze(xx, axis)
  285. np.testing.assert_equal(y, yy.numpy())
  286. if is_varnode:
  287. set_symbolic_shape(saved_symbolic_shape)
  288. @pytest.mark.parametrize("is_varnode", [True, False])
  289. def test_expand_dims(is_varnode):
  290. if is_varnode:
  291. network = Network()
  292. else:
  293. network = None
  294. x = np.arange(6, dtype="float32").reshape(2, 3)
  295. xx = make_tensor(x, network)
  296. for axis in [2, -3, (3, -4), (1, -4)]:
  297. y = np.expand_dims(x, axis)
  298. yy = F.expand_dims(xx, axis)
  299. np.testing.assert_equal(y, yy.numpy())
  300. def test_expand_dims_for_scalar():
  301. x = np.array(1, dtype="float32")
  302. xx = make_tensor(x, None)
  303. for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
  304. y = np.expand_dims(x, axis)
  305. yy = F.expand_dims(xx, axis)
  306. np.testing.assert_equal(y, yy.numpy())
  307. for axis in [1, -2, (1, 2), (-2, -3)]:
  308. np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
  309. np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis)
  310. @pytest.mark.parametrize("is_varnode", [True, False])
  311. def test_elemwise_dtype_promotion(is_varnode):
  312. if is_varnode:
  313. network = Network()
  314. else:
  315. network = None
  316. x = np.random.rand(2, 3).astype("float32")
  317. y = np.random.rand(1, 3).astype("float16")
  318. xx = make_tensor(x, network)
  319. yy = make_tensor(y, network)
  320. z = xx * yy
  321. np.testing.assert_equal(z.numpy(), x * y)
  322. z = xx + y
  323. np.testing.assert_equal(z.numpy(), x + y)
  324. z = x - yy
  325. np.testing.assert_equal(z.numpy(), x - y)
  326. @pytest.mark.parametrize("is_varnode", [True, False])
  327. def test_linspace(is_varnode):
  328. if is_varnode:
  329. network = Network()
  330. else:
  331. network = None
  332. cases = [
  333. {"input": [1, 9, 9]},
  334. {"input": [3, 10, 8]},
  335. ]
  336. opr_test(
  337. cases,
  338. F.linspace,
  339. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  340. network=network,
  341. )
  342. cases = [
  343. {"input": [9, 1, 9]},
  344. {"input": [10, 3, 8]},
  345. ]
  346. opr_test(
  347. cases,
  348. F.linspace,
  349. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  350. network=network,
  351. )
  352. cases = [
  353. {"input": [1, make_tensor(9, network), 9]},
  354. {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
  355. ]
  356. opr_test(
  357. cases,
  358. F.linspace,
  359. ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
  360. network=network,
  361. )
  362. @pytest.mark.parametrize("is_varnode", [True, False])
  363. def test_arange(is_varnode):
  364. if is_varnode:
  365. network = Network()
  366. else:
  367. network = None
  368. cases = [
  369. {"input": [1, 9, 1]},
  370. {"input": [2, 10, 2]},
  371. ]
  372. opr_test(
  373. cases,
  374. F.arange,
  375. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  376. network=network,
  377. )
  378. cases = [
  379. {"input": [9, 1, -1]},
  380. {"input": [10, 2, -2]},
  381. ]
  382. opr_test(
  383. cases,
  384. F.arange,
  385. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  386. network=network,
  387. )
  388. cases = [
  389. {"input": [9.3, 1.2, -0.5]},
  390. {"input": [10.3, 2.1, -1.7]},
  391. ]
  392. opr_test(
  393. cases,
  394. F.arange,
  395. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  396. network=network,
  397. )
  398. @pytest.mark.parametrize("is_varnode", [True, False])
  399. def test_round(is_varnode):
  400. if is_varnode:
  401. network = Network()
  402. else:
  403. network = None
  404. data1_shape = (15,)
  405. data2_shape = (25,)
  406. data1 = np.random.random(data1_shape).astype(np.float32)
  407. data2 = np.random.random(data2_shape).astype(np.float32)
  408. cases = [{"input": data1}, {"input": data2}]
  409. opr_test(cases, F.round, ref_fn=np.round, network=network)
  410. @pytest.mark.parametrize("is_varnode", [True, False])
  411. def test_flatten(is_varnode):
  412. if is_varnode:
  413. network = Network()
  414. else:
  415. network = None
  416. data0_shape = (2, 3, 4, 5)
  417. data1_shape = (4, 5, 6, 7)
  418. data0 = np.random.random(data0_shape).astype(np.float32)
  419. data1 = np.random.random(data1_shape).astype(np.float32)
  420. cases = [
  421. {"input": data0, "output": data0.flatten()},
  422. {"input": data1, "output": data1.flatten()},
  423. ]
  424. opr_test(cases, F.flatten, network=network)
  425. cases = [
  426. {"input": data0, "output": data0.reshape(2, -1)},
  427. {"input": data1, "output": data1.reshape(4, -1)},
  428. ]
  429. opr_test(cases, F.flatten, start_axis=1, network=network)
  430. cases = [
  431. {"input": data0, "output": data0.reshape(2, 3, -1)},
  432. {"input": data1, "output": data1.reshape(4, 5, -1)},
  433. ]
  434. opr_test(cases, F.flatten, start_axis=2, network=network)
  435. cases = [
  436. {"input": data0, "output": data0.reshape(2, -1, 5)},
  437. {"input": data1, "output": data1.reshape(4, -1, 7)},
  438. ]
  439. opr_test(
  440. cases, F.flatten, start_axis=1, end_axis=2, network=network,
  441. )
  442. @pytest.mark.parametrize("is_varnode", [True, False])
  443. def test_broadcast(is_varnode):
  444. if is_varnode:
  445. network = Network()
  446. else:
  447. network = None
  448. input1_shape = (20, 30)
  449. output1_shape = (30, 20, 30)
  450. data1 = np.random.random(input1_shape).astype(np.float32)
  451. input2_shape = (10, 1)
  452. output2_shape = (20, 10, 20)
  453. data2 = np.random.random(input2_shape).astype(np.float32)
  454. input3_shape = (10, 10)
  455. output3_shape = (10, 10)
  456. data3 = np.random.random(input3_shape).astype(np.float32)
  457. cases = [
  458. {
  459. "input": [data1, output1_shape],
  460. "output": np.broadcast_to(data1, output1_shape),
  461. },
  462. {
  463. "input": [data2, output2_shape],
  464. "output": np.broadcast_to(data2, output2_shape),
  465. },
  466. {
  467. "input": [data3, output3_shape],
  468. "output": np.broadcast_to(data3, output3_shape),
  469. },
  470. ]
  471. opr_test(cases, F.broadcast_to, network=network)
  472. x = F.ones((2, 1, 3))
  473. with pytest.raises(RuntimeError):
  474. F.broadcast_to(x, (2, 3, 4))
  475. with pytest.raises(RuntimeError):
  476. F.broadcast_to(x, (4, 1, 3))
  477. with pytest.raises(RuntimeError):
  478. F.broadcast_to(x, (1, 3))
  479. @pytest.mark.parametrize("is_trace", [True, False])
  480. def test_broadcast_on_empty_tensor(is_trace):
  481. input1_shape = (100, 0, 1)
  482. output1_shape = (100, 0, 10)
  483. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  484. input2_shape = (10, 0)
  485. output2_shape = (10, 10, 0)
  486. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  487. input3_shape = (0, 0, 1, 10)
  488. output3_shape = (10, 0, 0, 10, 10)
  489. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  490. def comp(out, target_shp):
  491. assert out._tuple_shape == target_shp
  492. def func(x, shp):
  493. return F.broadcast_to(x, shp)
  494. cases = [
  495. [data1, output1_shape],
  496. [data2, output2_shape],
  497. [data3, output3_shape],
  498. ]
  499. def test(func, inp, comp, target_shp):
  500. out = func(inp, target_shp)
  501. comp(out, target_shp)
  502. if is_trace:
  503. for symbolic in [False, True]:
  504. for inp, target_shp in cases:
  505. func_traced = trace(symbolic=symbolic)(func)
  506. test(func_traced, inp, comp, target_shp)
  507. test(func_traced, inp, comp, target_shp)
  508. test(func_traced, inp, comp, target_shp)
  509. else:
  510. for inp, target_shp in cases:
  511. test(func, inp, comp, target_shp)
  512. @pytest.mark.parametrize("is_varnode", [True, False])
  513. def test_utils_astensor1d(is_varnode):
  514. if is_varnode:
  515. network = Network()
  516. else:
  517. network = None
  518. reference = make_tensor(0, network)
  519. # literal
  520. x = [1, 2, 3]
  521. for dtype in [None, "float32"]:
  522. xx = astensor1d(x, reference, dtype=dtype)
  523. assert isinstance(xx, type(reference))
  524. np.testing.assert_equal(xx.numpy(), x)
  525. # numpy array
  526. x = np.asarray([1, 2, 3], dtype="int32")
  527. for dtype in [None, "float32"]:
  528. xx = astensor1d(x, reference, dtype=dtype)
  529. assert isinstance(xx, type(reference))
  530. np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
  531. # tensor
  532. x = make_tensor([1, 2, 3], network)
  533. for dtype in [None, "float32"]:
  534. xx = astensor1d(x, reference, dtype=dtype)
  535. assert isinstance(xx, type(reference))
  536. np.testing.assert_equal(xx.numpy(), x.numpy())
  537. # mixed
  538. x = [1, make_tensor(2, network), 3]
  539. for dtype in [None, "float32"]:
  540. xx = astensor1d(x, reference, dtype=dtype)
  541. assert isinstance(xx, type(reference))
  542. np.testing.assert_equal(xx.numpy(), [1, 2, 3])
  543. def test_device():
  544. x = tensor([1, 2, 3], dtype="float32")
  545. y1 = F.eye(x.shape, dtype="float32")
  546. y2 = F.eye(x.shape, dtype="float32", device=None)
  547. np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
  548. y3 = F.eye(x.shape, dtype="float32", device="xpux")
  549. y4 = F.eye(x.shape, dtype="float32", device=x.device)
  550. np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
  551. y5 = F.full((3, 2), 4, device=x.device)
  552. y6 = F.full((3, 2), 4, device="xpux")
  553. np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
  554. @pytest.mark.parametrize("is_varnode", [True, False])
  555. def test_identity(is_varnode):
  556. if is_varnode:
  557. network = Network()
  558. else:
  559. network = None
  560. x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
  561. y = F.copy(x)
  562. np.testing.assert_equal(y.numpy(), x)
  563. def copy_test(dst, src, network):
  564. data = np.random.random((2, 3)).astype(np.float32)
  565. x = make_tensor(data, device=src, network=network)
  566. y = F.copy(x, dst)
  567. assert np.allclose(data, y.numpy())
  568. if network is None:
  569. z = x.to(dst)
  570. assert np.allclose(data, z.numpy())
  571. @pytest.mark.require_ngpu(1)
  572. @pytest.mark.parametrize("is_varnode", [True, False])
  573. def test_copy_h2d(is_varnode):
  574. if is_varnode:
  575. network = Network()
  576. else:
  577. network = None
  578. copy_test("cpu0", "gpu0", network=network)
  579. @pytest.mark.require_ngpu(1)
  580. @pytest.mark.parametrize("is_varnode", [True, False])
  581. def test_copy_d2h(is_varnode):
  582. if is_varnode:
  583. network = Network()
  584. else:
  585. network = None
  586. copy_test("gpu0", "cpu0", network=network)
  587. @pytest.mark.require_ngpu(2)
  588. @pytest.mark.parametrize("is_varnode", [True, False])
  589. def test_copy_d2d(is_varnode):
  590. if is_varnode:
  591. network = Network()
  592. else:
  593. network = None
  594. copy_test("gpu0", "gpu1", network=network)
  595. copy_test("gpu0:0", "gpu0:1", network=network)
  596. @pytest.mark.require_ngpu(2)
  597. @pytest.mark.parametrize(
  598. "shape, device_src, device_dst",
  599. [
  600. ((0,), "cpu0", "cpu0"),
  601. ((10, 0), "cpu0", "cpu1"),
  602. ((2, 0, 3), "cpu0", "gpu0"),
  603. ((1, 0, 1, 0), "gpu0", "cpu0"),
  604. ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
  605. ],
  606. )
  607. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  608. def test_copy_empty(shape, device_src, device_dst, is_symbolic):
  609. inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)
  610. def func(inp):
  611. return F.copy(inp, device_dst)
  612. if is_symbolic is not None:
  613. func = trace(symbolic=is_symbolic)(func)
  614. for _ in range(3):
  615. out = func(inp)
  616. assert out.numpy().shape == shape
  617. assert out.device == device_dst
  618. if is_symbolic is None:
  619. break
  620. @pytest.mark.parametrize(
  621. "shape, repeats, axis",
  622. [
  623. ((2,), 2, 0),
  624. ((2, 3, 4, 5), 3, 0),
  625. ((2, 3, 4, 5), 4, 3),
  626. ((2,), 2, None),
  627. ((2, 3, 4, 5), 3, None),
  628. ((), 1, None),
  629. ((), 10, None),
  630. ],
  631. )
  632. @pytest.mark.parametrize("is_varnode", [True, False])
  633. def test_repeat(shape, repeats, axis, is_varnode):
  634. if is_varnode:
  635. network = Network()
  636. else:
  637. network = None
  638. def repeat_func(inp):
  639. return F.repeat(inp=inp, repeats=repeats, axis=axis)
  640. if shape != ():
  641. cases = [
  642. {"input": np.random.randn(*shape).astype("float32")},
  643. ]
  644. else:
  645. cases = [{"input": np.array(1.23)}]
  646. opr_test(
  647. cases,
  648. repeat_func,
  649. ref_fn=lambda inp: np.repeat(inp, repeats, axis),
  650. network=network,
  651. )
  652. @pytest.mark.parametrize(
  653. "shape, reps",
  654. [
  655. ((2,), (2,)),
  656. ((2, 3, 4, 5), (1, 1, 1, 1)),
  657. ((2, 3, 4, 5), (1, 2, 3, 4)),
  658. # FIXME: tile does not support ndim 7
  659. # ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
  660. ],
  661. )
  662. @pytest.mark.parametrize("is_varnode", [True])
  663. def test_tile(shape, reps, is_varnode):
  664. if is_varnode:
  665. network = Network()
  666. else:
  667. network = None
  668. def tile_func(inp):
  669. return F.tile(inp=inp, reps=reps)
  670. cases = [{"input": np.random.randn(*shape).astype("float32")}]
  671. opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)
  672. @pytest.mark.parametrize(
  673. "shape, shifts, axis",
  674. [
  675. ((2, 3), 0, None),
  676. ((2, 3), 1, 0),
  677. ((2, 3), 100, 0),
  678. ((2, 3), -100, 0),
  679. ((2, 3, 4, 5), (-1, 1), (0, 1)),
  680. ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
  681. ],
  682. )
  683. @pytest.mark.parametrize("is_varnode", [True, False])
  684. def test_roll(shape, shifts, axis, is_varnode):
  685. if is_varnode:
  686. network = Network()
  687. else:
  688. network = None
  689. inp = np.random.randn(*shape).astype("float32")
  690. def func(inp):
  691. return F.roll(inp, shifts, axis)
  692. cases = [
  693. {"input": inp},
  694. ]
  695. opr_test(
  696. cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
  697. )
  698. @pytest.mark.parametrize(
  699. "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
  700. )
  701. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  702. def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
  703. inp = tensor(np.random.randn(*shape).astype("float32"))
  704. def func(inp):
  705. return F.roll(inp, shifts, axis)
  706. if is_symbolic is not None:
  707. func = trace(symbolic=is_symbolic)(func)
  708. out_ref = np.roll(inp.numpy(), shifts, axis)
  709. for _ in range(3):
  710. out = F.roll(inp, shifts, axis)
  711. np.testing.assert_equal(out.numpy(), out_ref)
  712. if is_symbolic is None:
  713. break