
test_functional.py

# -*- coding: utf-8 -*-
import itertools
import platform
from functools import partial

import numpy as np
import pytest
from utils import opr_test

import megengine.amp as amp
import megengine.config as config
import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
import megengine.jit as jit
from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple
from megengine.device import get_device_count
from megengine.module import LayerNorm

_assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
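
# The tests below exercise ops in megengine.functional. `opr_test` (from the local
# `utils` test helper) runs each case and compares the result against the given NumPy
# reference function, optionally also under jit tracing where test_trace is requested.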


def test_where():
    maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
    xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
    yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32)

    maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_)
    xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
    yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)

    maskv2 = np.array([1, 1, 1], dtype=np.bool_)
    xv2 = np.array([1, 3, 2], dtype=np.float32)
    yv2 = np.array([5, 6, 9], dtype=np.float32)

    maskv3 = np.array([0, 0, 0], dtype=np.bool_)
    xv3 = np.array([1, 3, 2], dtype=np.float32)
    yv3 = np.array([5, 6, 9], dtype=np.float32)

    maskv4 = np.array(1, dtype=np.bool_)
    xv4 = np.array(1, dtype=np.float32)
    yv4 = np.array(0, dtype=np.float32)

    cases = [
        {"input": [maskv0, xv0, yv0]},
        {"input": [maskv1, xv1, yv1]},
        {"input": [maskv2, xv2, yv2]},
        {"input": [maskv3, xv3, yv3]},
        {"input": [maskv4, xv4, yv4]},
    ]
    opr_test(cases, F.where, ref_fn=np.where, test_trace=True)


def test_dropout():
    from megengine.autodiff import GradManager
    from megengine.core._imperative_rt.ops import set_global_rng_seed

    def test_dropout_with_shape(shape, rate):
        data = tensor(np.ones(shape, dtype=np.float32))
        gm = GradManager().attach([data])
        with gm:
            out = F.nn.dropout(data, rate, training=True)
            gm.backward(out, tensor(np.ones(shape, dtype=np.float32)))
            if len(shape) != 0:
                assert not out.numpy().all()
            np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7, 1e-7)

    def test_multiple_dropout(shape, rate):
        data = tensor(np.ones(shape, dtype=np.float32))
        gm = GradManager().attach([data])
        with gm:
            out1 = F.nn.dropout(data, rate, training=True)
            out2 = F.nn.dropout(out1, rate, training=True)
            out3 = F.nn.dropout(out2, rate, training=True)
            gm.backward(out3, tensor(np.ones(shape, dtype=np.float32)))
            np.testing.assert_allclose(out3.numpy(), data.grad.numpy(), 1e-7, 1e-7)

    def test_dropout_seed(shape, rate):
        data = tensor(np.random.randn(*shape), dtype="float32")
        set_global_rng_seed(111)
        out1 = F.nn.dropout(data, rate, training=True)
        out2 = F.nn.dropout(data, rate, training=True)
        assert not (out1.numpy() == out2.numpy()).all()

        set_global_rng_seed(111)
        out3 = F.nn.dropout(data, rate, training=True)
        assert (out1.numpy() == out3.numpy()).all()

        set_global_rng_seed(222)
        out4 = F.nn.dropout(data, rate, training=True)
        assert not (out1.numpy() == out4.numpy()).all()

    test_dropout_with_shape([], 0.4)
    test_dropout_with_shape([13, 17, 63, 21], 0.4)
    test_dropout_with_shape([16, 32, 64], 0.3)
    test_multiple_dropout([1024], 0.2)
    test_dropout_seed([16, 32], 0.2)


def test_matinv():
    shape1 = (5, 5)
    shape2 = (3, 9, 9)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    # make matrix diagonally dominant for numerical stability
    data1 += (np.eye(shape1[0]) * shape1[0]).astype("float32")
    data2 += np.broadcast_to((np.eye(shape2[1]) * shape2[1]).astype("float32"), shape2)

    cases = [
        {"input": data1},
        {"input": data2},
    ]

    opr_test(
        cases,
        F.matinv,
        compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-4),
        ref_fn=np.linalg.inv,
    )


def test_matmul():
    shape1 = 3
    shape2 = 3
    shape3 = (3, 5)
    shape4 = (5, 6)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")
    data4 = np.random.random(shape4).astype("float32")

    cases = [
        {"input": [data1, data2]},
        {"input": [data2, data3]},
        {"input": [data3, data4]},
    ]
    opr_test(cases, F.matmul, ref_fn=np.matmul)

    batch_size = 10
    shape1 = (2,)
    shape2 = (batch_size, 2, 3)
    shape3 = (batch_size, 3, 4)
    shape4 = (batch_size, 10, 4, 2)
    shape5 = (batch_size, 10, 2, 4)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")
    data4 = np.random.random(shape4).astype("float32")
    data5 = np.random.random(shape5).astype("float32")

    cases = [
        {"input": [data1, data2]},
        {"input": [data2, data3]},
        {"input": [data3, data4]},
        {"input": [data4, data5]},
    ]
    opr_test(cases, F.matmul, ref_fn=np.matmul)

    opr_test(
        [{"input": [data1, data4]}],
        F.matmul,
        ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
        transpose_b=True,
    )

    opr_test(
        [{"input": [data3, data2]}],
        F.matmul,
        ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
        transpose_a=True,
        transpose_b=True,
    )


@pytest.mark.parametrize(
    "shape_a, shape_b", [((0,), (0,)), ((10, 0), (0, 10)), ((3, 10, 0), (3, 0, 10)),],
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_matmul_empty_tensor(shape_a, shape_b, is_symbolic):
    def func(a, b):
        return F.matmul(a, b)

    if is_symbolic is not None:
        func = jit.trace(symbolic=is_symbolic)(func)

    a = tensor(np.random.randn(*shape_a))
    b = tensor(np.random.randn(*shape_b))
    for _ in range(3):
        out = func(a, b)
        assert np.all(out.numpy() == 0)
        if is_symbolic is None:
            break


def test_interpolate():
    def linear_interpolate():
        inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

        test_func = lambda inp: F.vision.interpolate(
            inp, scale_factor=2.0, mode="linear"
        )
        ref_func = lambda inp: F.vision.interpolate(inp, 4, mode="linear").numpy()

        cases = [{"input": inp}]
        opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)

    def many_batch_interpolate():
        inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))

        test_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0)
        ref_func = lambda inp: F.vision.interpolate(inp, [4, 4]).numpy()

        cases = [{"input": inp}]
        opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)

    def assign_corner_interpolate():
        inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

        test_func = lambda inp: F.vision.interpolate(inp, [4, 4])
        ref_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0).numpy()

        cases = [{"input": inp}]
        opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)

    def error_shape_linear_interpolate():
        inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

        with pytest.raises(ValueError):
            F.vision.interpolate(inp, scale_factor=2.0, mode="linear")

    def inappropriate_scale_linear_interpolate():
        inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

        with pytest.raises(ValueError):
            F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="linear")

    linear_interpolate()
    many_batch_interpolate()
    assign_corner_interpolate()
    error_shape_linear_interpolate()
    # inappropriate_scale_linear_interpolate()
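

# Helpers for the ROI tests below: `_save_to` returns a gradient callback that stores
# the incoming gradient as an attribute on the given tensor, and `_gen_roi_inp` builds
# a random feature map plus ROI boxes laid out as (batch_index, x0, y0, x1, y1).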


def _save_to(self, name="grad"):
    def callback(grad):
        setattr(self, name, grad)

    return callback


def _gen_roi_inp():
    inp_feat = np.random.randn(2, 32, 256, 256)
    rois = np.zeros((4, 5))
    rois[:, 0] = [0, 0, 1, 1]
    rois[:, 1:3] = np.random.rand(4, 2) * 100
    rois[:, 3:] = np.random.rand(4, 2) * 100 + 150

    inp_feat = tensor(inp_feat)
    rois = tensor(rois)
    return inp_feat, rois


def test_roi_align():
    inp_feat, rois = _gen_roi_inp()
    with Grad() as grad:
        grad.wrt(inp_feat, callback=_save_to(inp_feat))

        output_shape = (7, 7)
        out_feat = F.vision.roi_align(
            inp_feat,
            rois,
            output_shape=output_shape,
            mode="average",
            spatial_scale=1.0 / 4,
            sample_points=2,
            aligned=True,
        )
        assert make_shape_tuple(out_feat.shape) == (
            rois.shape[0],
            inp_feat.shape[1],
            *output_shape,
        )

        grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)


def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
    if random:
        inp_feat1 = np.random.randn(
            image_shape[0], image_shape[1], image_shape[2], image_shape[3]
        )
        inp_feat2 = np.random.randn(
            image_shape[0], image_shape[1], image_shape[2], image_shape[3]
        )
    else:
        inp_feat1 = np.ones(image_shape) * constant
        inp_feat2 = np.ones(image_shape) * constant

    return tensor(inp_feat1), tensor(inp_feat2)


def test_correlation():
    # test case 0: check the grad shape
    data1, data2 = _gen_correlation()
    with Grad() as grad:
        grad.wrt(data1, callback=_save_to(data1))

        out_feat = F.vision.correlation(
            data1,
            data2,
            kernel_size=5,
            max_displacement=4,
            stride1=2,
            stride2=2,
            pad_size=2,
            is_multiply=True,
        )
        grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)

    # test case 1: from https://github.com/NVIDIA/flownet2-pytorch/issues/194
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=0,
        stride1=1,
        stride2=1,
        pad_size=0,
        is_multiply=True,
    )
    assert abs(out_feat.sum() - 1) < 1e-9

    # test case 2: check subtraction of identical images
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=0,
        stride1=1,
        stride2=1,
        pad_size=0,
        is_multiply=False,
    )
    assert out_feat.sum() < 1e-9

    # test case 3: check subtraction of identical images
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=0,
        stride1=1,
        stride2=1,
        pad_size=0,
        is_multiply=False,
    )
    assert out_feat.sum() < 1e-9

    # test case 4: check correlation
    data1, _ = _gen_correlation(
        random=False, image_shape=(1, 1, 220, 220), constant=2.0
    )
    _, data2 = _gen_correlation(
        random=False, image_shape=(1, 1, 220, 220), constant=1.0
    )
    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=2,
        stride1=1,
        stride2=2,
        pad_size=0,
        is_multiply=False,
    )
    assert abs(out_feat.mean() - 1) < 1e-9


def test_roi_pooling():
    inp_feat, rois = _gen_roi_inp()
    with Grad() as grad:
        grad.wrt(inp_feat, callback=_save_to(inp_feat))

        output_shape = (7, 7)
        out_feat = F.vision.roi_pooling(
            inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
        )
        assert make_shape_tuple(out_feat.shape) == (
            rois.shape[0],
            inp_feat.shape[1],
            *output_shape,
        )

        grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)


def test_adaptive_avg_pool2d():
    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
    oshp = (2, 2)
    with Grad() as grad:
        grad.wrt(inp, callback=_save_to(inp))
        outp = F.adaptive_avg_pool2d(inp, oshp,)
        assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
        np.testing.assert_equal(
            outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
        )

        grad(outp, tensor(F.ones_like(outp)))
    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
    np.testing.assert_equal(
        inp.grad.numpy(),
        np.array(
            [
                [
                    [
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                    ]
                ]
            ],
            dtype=np.float32,
        ),
    )


def test_adaptive_max_pool2d():
    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
    oshp = (2, 2)
    with Grad() as grad:
        grad.wrt(inp, callback=_save_to(inp))
        outp = F.adaptive_max_pool2d(inp, oshp,)
        assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
        np.testing.assert_equal(
            outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
        )

        grad(outp, tensor(F.ones_like(outp)))
    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
    np.testing.assert_equal(
        inp.grad.numpy(),
        np.array(
            [
                [
                    [
                        [0.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0],
                        [0.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0],
                    ]
                ]
            ],
            dtype=np.float32,
        ),
    )


def test_one_hot():
    def onehot_low_dimension():
        inp = tensor(np.arange(1, 4, dtype=np.int32))
        out = F.one_hot(inp, num_classes=4)

        np.testing.assert_allclose(
            out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]
        )

    def onehot_high_dimension():
        arr = np.array(
            [[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]],
            dtype=np.int32,
        )

        inp = tensor(arr)
        out = F.one_hot(inp, 10)

        np.testing.assert_allclose(out.numpy(), np.eye(10, dtype=np.int32)[arr])

    onehot_low_dimension()
    onehot_high_dimension()


def test_interpolate_fastpath():
    # check shape
    test_cases = [
        [(1, 1, 10, 10), (5, 5)],
        [(1, 3, 10, 10), (20, 20)],
        [(10, 1, 10, 10), (1, 1)],
        [(10, 10, 1, 1), (10, 10)],
    ]
    for inp_shape, target_shape in test_cases:
        x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
        out = F.vision.interpolate(x, target_shape, mode="bilinear")
        assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
        assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]

    # check value
    x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
    out = F.vision.interpolate(x, (15, 5), mode="bilinear")
    np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))

    np_x = np.arange(32)
    x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
    out = F.vision.interpolate(x, (1, 1), mode="bilinear")
    np.testing.assert_equal(out.item(), np_x.mean())


@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
def test_warp_perspective(dt):
    inp_shape = (1, 1, 4, 4)
    x = tensor(np.arange(16, dtype=dt).reshape(inp_shape))
    M_shape = (1, 3, 3)
    # M defines a translation: dst(1, 1, h, w) = src(1, 1, h + 1, w + 1)
    M = tensor(
        np.array(
            [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
        ).reshape(M_shape)
    )
    outp = F.vision.warp_perspective(x, M, (2, 2))
    np.testing.assert_equal(outp.numpy(), np.array([[[[5, 6], [9, 10]]]], dtype=dt))


@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
def test_warp_perspective_mat_idx(dt):
    inp_shape = (2, 1, 4, 4)
    x = tensor(np.arange(32, dtype=dt).reshape(inp_shape))
    M_shape = (1, 3, 3)
    # M defines a translation: dst(1, 1, h, w) = src(1, 1, h + 1, w + 1)
    M = tensor(
        np.array(
            [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
        ).reshape(M_shape)
    )
    M = F.concat([M,] * 4, 0)
    outp = F.vision.warp_perspective(x, M, (2, 2), mat_idx=[0, 1, 1, 0])
    np.testing.assert_equal(
        outp.numpy(),
        np.array(
            [
                [[[5, 6], [9, 10]]],
                [[[21, 22], [25, 26]]],
                [[[21, 22], [25, 26]]],
                [[[5, 6], [9, 10]]],
            ],
            dtype=dt,
        ),
    )


def test_warp_affine():
    inp_shape = (1, 3, 3, 3)
    x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
    weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]
    outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="wrap")
    res = np.array(
        [
            [
                [[7.875, 8.875, 9.875], [8.90625, 9.90625, 10.90625]],
                [[18.75, 19.75, 20.75], [14.90625, 15.90625, 16.90625]],
            ]
        ],
        dtype=np.float32,
    )
    if not is_cuda_available():
        np.testing.assert_almost_equal(outp.numpy(), res, 5)


def test_remap():
    inp_shape = (1, 1, 4, 4)
    inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
    map_xy_shape = (1, 2, 2, 2)
    map_xy = tensor(
        np.array(
            [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
        ).reshape(map_xy_shape)
    )
    outp = F.vision.remap(inp, map_xy)
    np.testing.assert_equal(
        outp.numpy(), np.array([[[[1.0, 4.0], [4.0, 4.0]]]], dtype=np.float32)
    )


def test_binary_cross_entropy():
    data1_shape = (2, 2)
    label1_shape = (2, 2)
    data2_shape = (2, 3)
    label2_shape = (2, 3)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def compare_fn(x, y):
        np.testing.assert_allclose(x.numpy(), y, atol=5e-4)

    np.random.seed(123)
    data1 = np.random.uniform(size=data1_shape).astype(np.float32)
    label1 = np.random.uniform(size=label1_shape).astype(np.float32)
    expect1 = np.array(0.6361, dtype=np.float32)

    np.random.seed(123)
    data2 = np.random.uniform(size=data2_shape).astype(np.float32)
    label2 = np.random.uniform(size=label2_shape).astype(np.float32)
    expect2 = np.array(0.6750, dtype=np.float32)

    cases = [
        {"input": [data1, label1], "output": expect1,},
        {"input": [data2, label2], "output": expect2,},
    ]
    opr_test(cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)

    cases = [
        {"input": [sigmoid(data1), label1], "output": expect1,},
        {"input": [sigmoid(data2), label2], "output": expect2,},
    ]
    opr_test(
        cases,
        partial(F.nn.binary_cross_entropy, with_logits=False),
        compare_fn=compare_fn,
    )


def test_hinge_loss():
    np.random.seed(123)
    # cases with L1 norm
    cases = []
    for shape in [(2, 2), (2, 3)]:
        data = np.random.uniform(size=shape).astype(np.float32)
        label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
        expect = np.clip(0, np.inf, 1 - data * label).sum(axis=1).mean()
        cases.append({"input": [data, label], "output": expect})
    opr_test(cases, F.nn.hinge_loss)

    # cases with L2 norm
    cases = []
    for shape in [(2, 2), (2, 3)]:
        data = np.random.uniform(size=shape).astype(np.float32)
        label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
        expect = ((np.clip(0, np.inf, 1 - data * label) ** 2).sum(axis=1)).mean()
        cases.append({"input": [data, label], "output": expect})

    def hinge_loss_with_l2_norm(pred, label):
        return F.nn.hinge_loss(pred, label, "L2")

    opr_test(cases, hinge_loss_with_l2_norm)


@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_nms(is_symbolic):
    def fn(inp, scores):
        return F.vision.nms(
            inp,
            scores=scores,
            iou_thresh=0.5,
            max_output=None if is_symbolic is None else 4,
        )

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)

    x = np.array(
        [
            [0, 0, 100, 100],
            [10, 10, 100, 100],
            [50, 50, 100, 100],
            [100, 100, 150, 150],
        ],
        dtype=np.float32,
    )
    inp = tensor(x)

    scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
    for _ in range(3):
        result = fn(inp, scores=scores)
        np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))

    x = np.array([], dtype=np.float32,).reshape(0, 4)
    inp = tensor(x)
    scores = tensor([], dtype=np.float32)
    for _ in range(3):
        result = fn(inp, scores=scores)
        np.testing.assert_equal(result.numpy(), np.array([], dtype=np.int32))
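

# The quantized conv_bias tests below quantize float inputs/weights/bias to qint8/qint32
# via dtype.convert_to_qint8/convert_to_qint32, run F.quantized.conv_bias_activation
# (NCHW4 layout on CUDA, plain NCHW otherwise), and compare against a float32 F.conv2d
# reference within one output quantization step (atol=outp_scale).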


@pytest.mark.skipif(
    get_device_count("gpu") > 0, reason="cuda does not support nchw int8"
)
def test_conv_bias():
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="identity",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def convert_to_nchw4(var):
            var = F.reshape(
                var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
            )
            var = F.transpose(var, (0, 1, 3, 4, 2))
            return var

        def run_conv2d(inp, w, b):
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "relu":
                return F.relu(O)
            else:
                return O

        def run_conv_bias(inp, w, b, format="NCHW"):
            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
            if format == "NCHW4":
                inp = convert_to_nchw4(inp)
                w = convert_to_nchw4(w)
                b = convert_to_nchw4(b)
            return F.quantized.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        format = "NCHW4" if is_cuda_available() else "NCHW"

        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
            "float32"
        )
        if format == "NCHW4":
            result = F.transpose(result, (0, 1, 4, 2, 3))
        expected = F.flatten(expected)
        result = F.flatten(result)
        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")


@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
def test_batch_conv_bias():
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True,
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(N, OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def run_batch_conv_bias(inp, w, b):
            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
            result = F.quantized.batch_conv_bias_activation(
                inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype,
            )
            return result.astype("float32")

        expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0]
        expected = expected.astype(out_dtype).astype("float32")
        expected = F.flatten(expected)

        result = run_batch_conv_bias(inp_int8, w_int8, b_int32)
        result = F.flatten(result)

        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


def test_conv2d_autocast():
    """check that amp's result is equal to the manually converted result"""
    amp.enabled = True
    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
    weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
    amp.enabled = False
    expected = F.conv2d(
        inp.astype("float16"),
        weight.astype("float16"),
        None,
        (2, 2),
        (3, 3),
        (1, 1),
        1,
        compute_mode="float32",
    )
    assert out.dtype == np.float16
    assert expected.dtype == np.float16
    np.testing.assert_allclose(out.numpy(), expected.numpy())


def test_conv2d_zero_stride_numpy_array():
    inp = np.random.randn(3, 224, 224).astype(np.float32)
    inp = inp[np.newaxis, :]

    inp = tensor(inp, dtype=np.float32)
    weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)


def test_conv3d_zero_stride_numpy_array():
    inp = np.random.randn(3, 224, 224, 224).astype(np.float32)
    inp = inp[np.newaxis, :]

    inp = tensor(inp, dtype=np.float32)
    weight = tensor(np.random.randn(16, 3, 3, 3, 3), dtype=np.float32)
    out = F.conv3d(inp, weight, None, (2, 2, 2), (3, 3, 3), (1, 1, 1), 1)
    out.numpy()


def test_conv1d():
    inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
    weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
    out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
    np.testing.assert_equal(
        out.numpy(),
        np.array(
            [[[4, 4], [4, 4], [4, 4]], [[4, 4], [4, 4], [4, 4]]], dtype=np.float32
        ),
    )


def test_batchnorm2d_autocast():
    """check that amp's result is equal to the manually converted result"""
    amp.enabled = True
    tshape = (1, 3, 224, 224)
    pshape = (1, 3, 1, 1)
    inp = tensor(np.random.randn(*tshape), dtype=np.float32)
    weight = tensor(np.ones(pshape, dtype=np.float32))
    bias = tensor(np.zeros(pshape, dtype=np.float32))

    out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)

    amp.enabled = False
    expected = F.batch_norm(
        inp.astype("float16"),
        weight=weight,
        bias=bias,
        training=True,
        inplace=False,
        compute_mode="float32",
    )
    assert out.dtype == np.float16
    assert expected.dtype == np.float16
    np.testing.assert_allclose(out.numpy(), expected.numpy())


def test_conv3d():
    inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
    weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
    out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
    np.testing.assert_equal(
        out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
    )


def test_condtake():
    x = np.array([[1, 2, 3], [4, 5, 6]])
    y = np.array([[True, False, True], [False, True, True]])
    xx = tensor(x)
    yy = tensor(y)
    val, idx = F.cond_take(yy, xx)
    np.testing.assert_equal(val.numpy(), x[y])
    np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])


# renamed from a second `test_condtake` so it no longer shadows the test above
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_condtake_symbolic(is_symbolic):
    shapes = [
        (3, 3, 3),
        (0,),
        (3, 0, 3),
    ]

    def fn(mask, data):
        return F.cond_take(mask, data)

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)

    for shp in shapes:
        x_np = np.random.randn(*shp).astype("float32")
        mask_np = x_np > 0
        x = tensor(x_np)
        mask = tensor(mask_np)
        ref_out = x_np[mask_np]
        ref_idx = mask_np.flatten().nonzero()[0]
        for i in range(3):
            out, idx = fn(mask, x)
            np.testing.assert_equal(out.numpy(), ref_out)
            np.testing.assert_equal(idx.numpy(), ref_idx)
            if is_symbolic is None:
                break


def test_condtake_is_same():
    op1 = builtin.CondTake()
    op2 = builtin.CondTake()
    assert op1 == op2


def test_nms_is_same():
    op1 = builtin.NMSKeep(0.7, 100)
    op2 = builtin.NMSKeep(0.7, 100)
    op3 = builtin.NMSKeep(0.8, 100)
    op4 = builtin.NMSKeep(0.7, 200)
    assert op1 == op2
    assert op1 != op3
    assert op1 != op4
    assert op3 != op4


def test_argmxx_on_inf():
    def run_argmax():
        x = F.zeros((100, 100))
        x[:] = -float("inf")
        idxs = F.argmax(x, axis=0)
        return idxs

    def run_argmin():
        x = F.zeros((100, 100))
        x[:] = float("inf")
        idxs = F.argmin(x, axis=0)
        return idxs

    assert all(run_argmax() >= 0)
    assert all(run_argmin() >= 0)


def test_deformable_psroi_pooling():
    inp = np.random.random((1, 256, 64, 64)).astype("float32")
    rois = np.random.random((1, 5)).astype("float32")
    trans = np.random.random((24, 2, 7, 7)).astype("float32")

    pooled_h = 7
    pooled_w = 7
    sample_per_part = 4
    no_trans = False
    part_size = 7
    spatial_scale = 1.0 / 64
    trans_std = 0.1

    y = F.deformable_psroi_pooling(
        tensor(inp),
        tensor(rois),
        tensor(trans),
        no_trans,
        part_size,
        pooled_h,
        pooled_w,
        sample_per_part,
        spatial_scale,
        trans_std,
    )


def test_cvt_color():
    def rgb2gray(rgb):
        return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])

    def bgr2gray(bgr):
        return np.dot(bgr[..., :3], [0.114, 0.587, 0.299])

    inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
    out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32)
    x = tensor(inp)
    y = F.vision.cvt_color(x, mode="RGB2GRAY")
    np.testing.assert_allclose(y.numpy(), out, atol=1e-5)

    out1 = np.expand_dims(bgr2gray(inp), 3).astype(np.float32)
    y1 = F.vision.cvt_color(x, mode="BGR2GRAY")
    np.testing.assert_allclose(y1.numpy(), out1, atol=1e-5)


@pytest.mark.parametrize("val", [2, [2,], [2, 3]])
def test_ones(val):
    shp = tensor(val)
    np_shp = np.array(val)
    np.testing.assert_equal(F.ones(shp), np.ones(np_shp))


def test_assert_equal():
    shape = (2, 3, 4, 5)
    x = F.ones(shape, dtype=np.float32)
    y = F.zeros(shape, dtype=np.float32) + 1.00001
    z = F.utils._assert_equal(x, y)


def test_assert_not_equal():
    shape = (2, 3, 4, 5)
    x = F.ones(shape, dtype=np.float32)
    y = F.zeros(shape, dtype=np.float32) + 1.1
    with pytest.raises(RuntimeError):
        z = F.utils._assert_equal(x, y)


def test_neg_axis():
    x = tensor(np.random.normal(0, 1, (32, 5)))

    y = F.argmax(x, axis=-1)
    yy = F.argmax(x, axis=1)
    np.testing.assert_equal(y.numpy(), yy.numpy())

    y = F.argmax(x, axis=(-1, -2))
    yy = F.argmax(x, axis=(0, 1))
    np.testing.assert_equal(y.numpy(), yy.numpy())

    y = F.argmin(x, axis=(-1, -2))
    yy = F.argmin(x, axis=(0, 1))
    np.testing.assert_equal(y.numpy(), yy.numpy())


def test_sliding_window():
    N, C, H, W = 2, 3, 7, 8
    inp = np.random.normal(size=(N, C, H, W))
    ph, pw = 1, 2
    sh, sw = 2, 1
    wh, ww = 3, 2
    dh, dw = 1, 3
    s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1
    inp_pad = np.zeros((N, C, H + ph * 2, W + pw * 2))
    inp_pad[:, :, ph : H + ph, pw : W + pw] = inp
    gt_out = np.empty(
        (N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww), dtype=np.float32
    )
    for n, c, oh, ow in itertools.product(*map(range, gt_out.shape[:4])):
        ih, iw = oh * sh, ow * sw
        gt_out[n, c, oh, ow, :] = inp_pad[
            n, c, ih : ih + (wh - 1) * dh + 1 : dh, iw : iw + (ww - 1) * dw + 1 : dw
        ]

    out = F.sliding_window(
        tensor(inp), (wh, ww), padding=(ph, pw), stride=(sh, sw), dilation=(dh, dw)
    )
    np.testing.assert_equal(gt_out, out.numpy())


def test_sliding_window_transpose():
    N, C, H, W = 2, 3, 7, 8
    ph, pw = 1, 2
    sh, sw = 2, 1
    wh, ww = 3, 2
    dh, dw = 1, 3
    s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1
    inp = np.random.normal(
        size=(N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww)
    ).astype(np.float32)
    gt_out = np.zeros((N, C, H, W), dtype=np.float32)

    for n, c in itertools.product(*map(range, inp.shape[:2])):
        oh = 0
        for ih in range(-ph, H + ph - dh * (wh - 1), sh):
            ow = 0
            for iw in range(-pw, W + pw - dw * (ww - 1), sw):
                for kh, kw in itertools.product(*map(range, inp.shape[-2:])):
                    ih2 = ih + dh * kh
                    iw2 = iw + dw * kw
                    if ih2 >= 0 and ih2 < H and iw2 >= 0 and iw2 < W:
                        gt_out[n, c, ih2, iw2] += inp[n, c, oh, ow, kh, kw]
                ow += 1
            oh += 1

    out = F.sliding_window_transpose(
        tensor(inp),
        (H, W),
        (wh, ww),
        padding=(ph, pw),
        stride=(sh, sw),
        dilation=(dh, dw),
    )
    np.testing.assert_equal(gt_out, out.numpy())


def test_pad():
    src = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
    dst = np.pad(src, ((2, 2), (2, 2)), "constant")
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "CONSTANT")
    np.testing.assert_allclose(res, dst, atol=1e-5)

    dst = np.pad(src, ((2, 2), (2, 2)), "constant", constant_values=3)
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "CONSTANT", constant_value=3)
    np.testing.assert_allclose(res, dst, atol=1e-5)

    dst = np.pad(src, ((2, 2), (2, 2)), "edge")
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "EDGE")
    np.testing.assert_allclose(res, dst, atol=1e-5)

    dst = np.pad(src, ((2, 2), (2, 2)), "reflect")
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "REFLECT")
    np.testing.assert_allclose(res, dst, atol=1e-5)
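

# NumPy reference implementation of pixel shuffle, used as the golden result in the
# F.pixel_shuffle tests below: channels are regrouped so that an input of shape
# (..., C * r * r, H, W) becomes (..., C, H * r, W * r).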


def pixel_shuffle(data, r):
    high_dim = data.shape[:-3]
    data = data.reshape(-1, data.shape[-3], data.shape[-2], data.shape[-1])
    inn, ic, ih, iw = data.shape
    res = np.zeros((inn, int(ic / (r * r)), ih * r, iw * r))
    for n in range(inn):
        for c in range(ic):
            for h in range(ih):
                for w in range(iw):
                    res[
                        n,
                        int(c / r / r),
                        h * r + int((c % (r * r)) / r),
                        w * r + c % r,
                    ] = data[n, c, h, w]
    if len(high_dim) > 0:
        res = res.reshape((*high_dim, int(ic / r / r), ih * r, iw * r))
    else:
        res = res[0]
    return res


def test_pixel_shuffle():
    # ndim = 3
    inp = np.arange(16 * 3 * 3).reshape(16, 3, 3)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=4)
    golden = pixel_shuffle(inp, 4)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
    golden = pixel_shuffle(inp_float, 2)
    np.testing.assert_equal(out.numpy(), golden)

    # ndim = 4
    inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=3)
    golden = pixel_shuffle(inp, 3)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=3)
    golden = pixel_shuffle(inp_float, 3)
    np.testing.assert_equal(out.numpy(), golden)

    # ndim = 5
    inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
    golden = pixel_shuffle(inp, 2)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
    golden = pixel_shuffle(inp_float, 2)
    np.testing.assert_equal(out.numpy(), golden)

    # ndim = 6
    inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=5)
    golden = pixel_shuffle(inp, 5)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=5)
    golden = pixel_shuffle(inp_float, 5)
    np.testing.assert_equal(out.numpy(), golden)

    # ndim = 7
    inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
    golden = pixel_shuffle(inp, 2)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
    golden = pixel_shuffle(inp_float, 2)
    np.testing.assert_equal(out.numpy(), golden)


@pytest.mark.parametrize("type", ["int32", "float32"])
@pytest.mark.parametrize("is_symbolic", [False, True])
def test_pixel_shuffle_symbolic(is_symbolic, type):
    def fn(inp, upscale_factor):
        return F.pixel_shuffle(inp, upscale_factor=upscale_factor)

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)

    inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5).astype(type))
    golden = pixel_shuffle(inp, 2)
    for _ in range(3):
        out = fn(inp, 2)
        np.testing.assert_equal(out.numpy(), golden)
        if is_symbolic is None:
            break


def test_set_conv2d_config():
    """check that setting config via the context manager matches the explicitly passed compute_mode"""
    config._compute_mode = "float32"
    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float16)
    weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float16)
    config_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
    config._compute_mode = "default"

    with config._override(compute_mode="float32"):
        context_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)

    expected = F.conv2d(
        inp, weight, None, (2, 2), (3, 3), (1, 1), 1, compute_mode="float32",
    )
    np.testing.assert_allclose(config_out.numpy(), expected.numpy())
    np.testing.assert_allclose(context_out.numpy(), expected.numpy())


def test_set_warp_perspective_config():
    config._conv_format = "NHWC"
    inp_shape = (1, 1, 4, 4)
    inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
    M_shape = (1, 3, 3)
    M = Tensor(np.random.randn(3, 3), dtype=np.float32).reshape(M_shape)
    config_out = F.vision.warp_perspective(inp, M, (2, 2))
    config._conv_format = "default"

    with config._override(conv_format="NHWC"):
        context_out = F.vision.warp_perspective(inp, M, (2, 2))

    expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC")
    np.testing.assert_allclose(config_out.numpy(), expected.numpy())
    np.testing.assert_allclose(context_out.numpy(), expected.numpy())


@pytest.mark.parametrize("stride", [(1, 1)])
@pytest.mark.parametrize("padding", [(1, 1)])
@pytest.mark.parametrize("dilation", [(1, 1)])
@pytest.mark.parametrize("ksize", [(3, 3)])
@pytest.mark.parametrize("groups", [1, 2])
def test_local_conv2d(stride, padding, dilation, ksize, groups):
    batch_size, in_channels, out_channels = 2, 4, 8
    input_height, input_width = 10, 10
    output_height = (input_height + padding[0] * 2 - ksize[0]) // stride[0] + 1
    output_width = (input_width + padding[1] * 2 - ksize[1]) // stride[1] + 1

    def local_conv2d_np(data, weight, stride, padding, dilation):
        # naive calculation using numpy
        # only tests the case output_height == input_height, output_width == input_width
        data = np.pad(data, ((0, 0), (0, 0), (1, 1), (1, 1)))
        expected = np.zeros(
            (batch_size, out_channels, output_height, output_width), dtype=np.float32,
        )
        ic_group_size = in_channels // groups
        oc_group_size = out_channels // groups
        for n, oc, oh, ow in itertools.product(
            *map(range, [batch_size, out_channels, output_height, output_width])
        ):
            ih, iw = oh * stride[0], ow * stride[1]
            g_id = oc // oc_group_size
            expected[n, oc, ih, iw] = np.sum(
                data[
                    n,
                    g_id * ic_group_size : (g_id + 1) * ic_group_size,
                    ih : ih + ksize[0],
                    iw : iw + ksize[1],
                ]
                * weight[g_id, oh, ow, :, :, :, oc % oc_group_size]
            )
        return expected

    data = np.random.rand(batch_size, in_channels, input_height, input_width).astype(
        "float32"
    )
    weight = np.random.rand(
        groups,
        output_height,
        output_width,
        in_channels // groups,
        *ksize,
        out_channels // groups,
    ).astype("float32")
    output = F.local_conv2d(
        tensor(data),
        tensor(weight),
        None,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    ref = local_conv2d_np(data, weight, stride, padding, dilation)
    np.testing.assert_almost_equal(output.numpy(), ref, 5)