
test_datasets_cifarop.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Cifar10 and Cifar100 dataset operators
"""
import os
import pytest
import numpy as np
import matplotlib.pyplot as plt
import mindspore.dataset as ds
from mindspore import log as logger

DATA_DIR_10 = "../data/dataset/testCifar10Data"
DATA_DIR_100 = "../data/dataset/testCifar100Data"
NO_BIN_DIR = "../data/dataset/testMnistData"


def load_cifar(path, kind="cifar10"):
    """
    Load Cifar10/100 data from the raw .bin files under `path`
    """
    raw = np.empty(0, dtype=np.uint8)
    for file_name in os.listdir(path):
        if file_name.endswith(".bin"):
            with open(os.path.join(path, file_name), mode='rb') as file:
                raw = np.append(raw, np.fromfile(file, dtype=np.uint8), axis=0)
    if kind == "cifar10":
        raw = raw.reshape(-1, 3073)
        labels = raw[:, 0]
        images = raw[:, 1:]
    elif kind == "cifar100":
        raw = raw.reshape(-1, 3074)
        labels = raw[:, :2]
        images = raw[:, 2:]
    else:
        raise ValueError("Invalid parameter value")
    images = images.reshape(-1, 3, 32, 32)
    images = images.transpose(0, 2, 3, 1)
    return images, labels
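

# Note (added for clarity): the 3073/3074-byte record strides in load_cifar
# follow the standard CIFAR binary layout. A CIFAR-10 record is
# <1 x label><3072 x pixel>; a CIFAR-100 record is
# <1 x coarse label><1 x fine label><3072 x pixel>. The 3072 pixel bytes form
# a 3x32x32 CHW image that is transposed to 32x32x3 HWC above.
# Minimal usage sketch (assuming the test data directory is present):
#
#   images, labels = load_cifar(DATA_DIR_10)
#   assert images.shape[1:] == (32, 32, 3) and images.dtype == np.uint8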


def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples
    """
    num_samples = len(images)
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(images[i])
        plt.title(labels[i])
    plt.show()


### Testcases for Cifar10Dataset Op ###


def test_cifar10_content_check():
    """
    Validate Cifar10Dataset image readings
    """
    logger.info("Test Cifar10Dataset Op with content check")
    data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100, shuffle=False)
    images, labels = load_cifar(DATA_DIR_10)
    num_iter = 0
    # in this example, each dictionary has keys "image" and "label"
    for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(d["image"], images[i])
        np.testing.assert_array_equal(d["label"], labels[i])
        num_iter += 1
    assert num_iter == 100


def test_cifar10_basic():
    """
    Validate Cifar10Dataset basic usage
    """
    logger.info("Test Cifar10Dataset Op")

    # case 0: test loading the whole dataset
    data0 = ds.Cifar10Dataset(DATA_DIR_10)
    num_iter0 = 0
    for _ in data0.create_dict_iterator(num_epochs=1):
        num_iter0 += 1
    assert num_iter0 == 10000

    # case 1: test num_samples
    data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    num_iter1 = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter1 += 1
    assert num_iter1 == 100

    # case 2: test num_parallel_workers
    data2 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=50, num_parallel_workers=1)
    num_iter2 = 0
    for _ in data2.create_dict_iterator(num_epochs=1):
        num_iter2 += 1
    assert num_iter2 == 50

    # case 3: test repeat
    data3 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    data3 = data3.repeat(3)
    num_iter3 = 0
    for _ in data3.create_dict_iterator(num_epochs=1):
        num_iter3 += 1
    assert num_iter3 == 300

    # case 4: test batch with drop_remainder=False
    data4 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    assert data4.get_dataset_size() == 100
    assert data4.get_batch_size() == 1
    data4 = data4.batch(batch_size=7)  # drop_remainder defaults to False
    assert data4.get_dataset_size() == 15
    assert data4.get_batch_size() == 7
    num_iter4 = 0
    for _ in data4.create_dict_iterator(num_epochs=1):
        num_iter4 += 1
    assert num_iter4 == 15

    # case 5: test batch with drop_remainder=True
    data5 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    assert data5.get_dataset_size() == 100
    assert data5.get_batch_size() == 1
    data5 = data5.batch(batch_size=7, drop_remainder=True)  # the last incomplete batch is dropped
    assert data5.get_dataset_size() == 14
    assert data5.get_batch_size() == 7
    num_iter5 = 0
    for _ in data5.create_dict_iterator(num_epochs=1):
        num_iter5 += 1
    assert num_iter5 == 14
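

# Note (added for clarity): the batch arithmetic above works out as
#   drop_remainder=False: ceil(100 / 7) = 15 batches (the last holds 100 - 14*7 = 2 rows)
#   drop_remainder=True:  floor(100 / 7) = 14 batches (the 2 leftover rows are dropped)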


def test_cifar10_pk_sampler():
    """
    Test Cifar10Dataset with PKSampler
    """
    logger.info("Test Cifar10Dataset Op with PKSampler")
    golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
              5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]
    sampler = ds.PKSampler(3)
    data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
    num_iter = 0
    label_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        label_list.append(item["label"])
        num_iter += 1
    np.testing.assert_array_equal(golden, label_list)
    assert num_iter == 30
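

# Note (added for clarity): PKSampler(3) draws K=3 samples from each of the
# P=10 CIFAR-10 classes, emitted grouped by class in label order, which is
# exactly the golden sequence above (10 classes x 3 samples = 30 rows).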


def test_cifar10_sequential_sampler():
    """
    Test Cifar10Dataset with SequentialSampler
    """
    logger.info("Test Cifar10Dataset Op with SequentialSampler")
    num_samples = 30
    sampler = ds.SequentialSampler(num_samples=num_samples)
    data1 = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
    data2 = ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_samples=num_samples)
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_equal(item1["label"], item2["label"])
        num_iter += 1
    assert num_iter == num_samples
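

# Note (added for clarity): the test relies on SequentialSampler(num_samples=N)
# being equivalent to shuffle=False with num_samples=N. Both read the first N
# rows in file order, so the two pipelines must yield identical labels.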


def test_cifar10_exception():
    """
    Test error cases for Cifar10Dataset
    """
    logger.info("Test error cases for Cifar10Dataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.Cifar10Dataset(DATA_DIR_10, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.Cifar10Dataset(DATA_DIR_10, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.Cifar10Dataset(DATA_DIR_10, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=256)

    error_msg_7 = "no .bin files found"
    with pytest.raises(RuntimeError, match=error_msg_7):
        ds1 = ds.Cifar10Dataset(NO_BIN_DIR)
        for _ in ds1:
            pass


def test_cifar10_visualize(plot=False):
    """
    Visualize Cifar10Dataset results
    """
    logger.info("Test Cifar10Dataset visualization")
    data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=10, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        label = item["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (32, 32, 3)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 10
    if plot:
        visualize_dataset(image_list, label_list)


### Testcases for Cifar100Dataset Op ###


def test_cifar100_content_check():
    """
    Validate Cifar100Dataset image readings
    """
    logger.info("Test Cifar100Dataset with content check")
    data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, shuffle=False)
    images, labels = load_cifar(DATA_DIR_100, kind="cifar100")
    num_iter = 0
    # in this example, each dictionary has keys "image", "coarse_label" and "fine_label"
    for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(d["image"], images[i])
        np.testing.assert_array_equal(d["coarse_label"], labels[i][0])
        np.testing.assert_array_equal(d["fine_label"], labels[i][1])
        num_iter += 1
    assert num_iter == 100


def test_cifar100_basic():
    """
    Test Cifar100Dataset
    """
    logger.info("Test Cifar100Dataset")

    # case 1: test num_samples
    data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
    num_iter1 = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter1 += 1
    assert num_iter1 == 100

    # case 2: test repeat
    data1 = data1.repeat(2)
    num_iter2 = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter2 += 1
    assert num_iter2 == 200

    # case 3: test num_parallel_workers
    data2 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, num_parallel_workers=1)
    num_iter3 = 0
    for _ in data2.create_dict_iterator(num_epochs=1):
        num_iter3 += 1
    assert num_iter3 == 100

    # case 4: test batch with drop_remainder=False
    data3 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
    assert data3.get_dataset_size() == 100
    assert data3.get_batch_size() == 1
    data3 = data3.batch(batch_size=3)
    assert data3.get_dataset_size() == 34
    assert data3.get_batch_size() == 3
    num_iter4 = 0
    for _ in data3.create_dict_iterator(num_epochs=1):
        num_iter4 += 1
    assert num_iter4 == 34

    # case 5: test batch with drop_remainder=True
    data4 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
    data4 = data4.batch(batch_size=3, drop_remainder=True)
    assert data4.get_dataset_size() == 33
    assert data4.get_batch_size() == 3
    num_iter5 = 0
    for _ in data4.create_dict_iterator(num_epochs=1):
        num_iter5 += 1
    assert num_iter5 == 33


def test_cifar100_pk_sampler():
    """
    Test Cifar100Dataset with PKSampler
    """
    logger.info("Test Cifar100Dataset with PKSampler")
    golden = list(range(20))
    sampler = ds.PKSampler(1)
    data = ds.Cifar100Dataset(DATA_DIR_100, sampler=sampler)
    num_iter = 0
    label_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        label_list.append(item["coarse_label"])
        num_iter += 1
    np.testing.assert_array_equal(golden, label_list)
    assert num_iter == 20
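

# Note (added for clarity; the keying column is an assumption inferred from
# the golden values): CIFAR-100 groups its 100 fine labels into 20 coarse
# superclasses, and PKSampler(1) here yields exactly one row per coarse label,
# hence golden = list(range(20)) and 20 iterations.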


def test_cifar100_exception():
    """
    Test error cases for Cifar100Dataset
    """
    logger.info("Test error cases for Cifar100Dataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.Cifar100Dataset(DATA_DIR_100, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.Cifar100Dataset(DATA_DIR_100, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.Cifar100Dataset(DATA_DIR_100, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=256)

    error_msg_7 = "no .bin files found"
    with pytest.raises(RuntimeError, match=error_msg_7):
        ds1 = ds.Cifar100Dataset(NO_BIN_DIR)
        for _ in ds1:
            pass


def test_cifar100_visualize(plot=False):
    """
    Visualize Cifar100Dataset results
    """
    logger.info("Test Cifar100Dataset visualization")
    data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=10, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        coarse_label = item["coarse_label"]
        fine_label = item["fine_label"]
        image_list.append(image)
        label_list.append("coarse_label {}\nfine_label {}".format(coarse_label, fine_label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (32, 32, 3)
        assert image.dtype == np.uint8
        assert coarse_label.dtype == np.uint32
        assert fine_label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 10
    if plot:
        visualize_dataset(image_list, label_list)


def test_cifar_usage():
    """
    Test the usage flag of Cifar10Dataset and Cifar100Dataset
    """
    logger.info("Test Cifar10Dataset and Cifar100Dataset usage flag")

    # flag: if True, test cifar10; else test cifar100
    def test_config(usage, flag=True, cifar_path=None):
        if cifar_path is None:
            cifar_path = DATA_DIR_10 if flag else DATA_DIR_100
        try:
            data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage)
            num_rows = 0
            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    # test the usage of CIFAR10 (the test data directory contains only train data)
    assert test_config("train") == 10000
    assert test_config("all") == 10000
    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
    assert "no valid data matching the dataset API Cifar10Dataset" in test_config("test")

    # test the usage of CIFAR100 (the test data directory contains only test data)
    assert test_config("test", False) == 10000
    assert test_config("all", False) == 10000
    assert "no valid data matching the dataset API Cifar100Dataset" in test_config("train", False)
    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid", False)

    # change this directory to the folder that contains all cifar10 files
    all_cifar10 = None
    if all_cifar10 is not None:
        assert test_config("train", True, all_cifar10) == 50000
        assert test_config("test", True, all_cifar10) == 10000
        assert test_config("all", True, all_cifar10) == 60000
        assert ds.Cifar10Dataset(all_cifar10, usage="train").get_dataset_size() == 50000
        assert ds.Cifar10Dataset(all_cifar10, usage="test").get_dataset_size() == 10000
        assert ds.Cifar10Dataset(all_cifar10, usage="all").get_dataset_size() == 60000

    # change this directory to the folder that contains all cifar100 files
    all_cifar100 = None
    if all_cifar100 is not None:
        assert test_config("train", False, all_cifar100) == 50000
        assert test_config("test", False, all_cifar100) == 10000
        assert test_config("all", False, all_cifar100) == 60000
        assert ds.Cifar100Dataset(all_cifar100, usage="train").get_dataset_size() == 50000
        assert ds.Cifar100Dataset(all_cifar100, usage="test").get_dataset_size() == 10000
        assert ds.Cifar100Dataset(all_cifar100, usage="all").get_dataset_size() == 60000
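

# Note (added for clarity; an assumption based on the standard CIFAR binary
# distribution rather than anything asserted above): the usage flag selects
# files by their conventional names, i.e. for CIFAR-10 "train" reads
# data_batch_*.bin and "test" reads test_batch.bin, for CIFAR-100 train.bin
# vs. test.bin, while "all" reads both. The asserts above only pin down the
# resulting row counts (50000 train + 10000 test = 60000 all).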


def test_cifar_exception_file_path():
    """
    Test that a failing map operation raises a RuntimeError for each column
    """
    def exception_func(item):
        raise Exception("Error occurred!")

    try:
        data = ds.Cifar10Dataset(DATA_DIR_10)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        num_rows = 0
        for _ in data.create_dict_iterator():
            num_rows += 1
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.Cifar10Dataset(DATA_DIR_10)
        data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
        num_rows = 0
        for _ in data.create_dict_iterator():
            num_rows += 1
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.Cifar100Dataset(DATA_DIR_100)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        num_rows = 0
        for _ in data.create_dict_iterator():
            num_rows += 1
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.Cifar100Dataset(DATA_DIR_100)
        data = data.map(operations=exception_func, input_columns=["coarse_label"], num_parallel_workers=1)
        num_rows = 0
        for _ in data.create_dict_iterator():
            num_rows += 1
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.Cifar100Dataset(DATA_DIR_100)
        data = data.map(operations=exception_func, input_columns=["fine_label"], num_parallel_workers=1)
        num_rows = 0
        for _ in data.create_dict_iterator():
            num_rows += 1
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)


def test_cifar10_pk_sampler_get_dataset_size():
    """
    Test Cifar10Dataset with PKSampler and get_dataset_size
    """
    sampler = ds.PKSampler(3)
    data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
    num_iter = 0
    ds_sz = data.get_dataset_size()
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert ds_sz == num_iter == 30


def test_cifar10_with_chained_sampler_get_dataset_size():
    """
    Test Cifar10Dataset with a PKSampler chained under a SequentialSampler and get_dataset_size
    """
    sampler = ds.SequentialSampler(start_index=0, num_samples=5)
    child_sampler = ds.PKSampler(4)
    sampler.add_child(child_sampler)
    data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
    num_iter = 0
    ds_sz = data.get_dataset_size()
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert ds_sz == num_iter == 5
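

# Note (added for clarity; an assumption about sampler-chaining semantics):
# the child PKSampler(4) first draws 4 samples per class, and the parent
# SequentialSampler(start_index=0, num_samples=5) then takes the first 5 of
# the child's indices, so both get_dataset_size() and the iteration count are 5.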


if __name__ == '__main__':
    test_cifar10_content_check()
    test_cifar10_basic()
    test_cifar10_pk_sampler()
    test_cifar10_sequential_sampler()
    test_cifar10_exception()
    test_cifar10_visualize(plot=False)
    test_cifar100_content_check()
    test_cifar100_basic()
    test_cifar100_pk_sampler()
    test_cifar100_exception()
    test_cifar100_visualize(plot=False)
    test_cifar_usage()
    test_cifar_exception_file_path()
    test_cifar10_with_chained_sampler_get_dataset_size()
    test_cifar10_pk_sampler_get_dataset_size()