You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_imagefolder.py 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846
  1. # Copyright 2019-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. import mindspore.dataset.vision.c_transforms as vision
  18. from mindspore import log as logger
# ImageFolder fixture shared by every test in this module. The path is relative
# ("../"), so it is resolved against the working directory — presumably the tests
# are launched from the dataset test directory (TODO confirm).
# Tests below rely on it holding 44 images; with a SequentialSampler the labels
# come out as 0, 1, 2, 3 with 11 rows each, and folders named "class1" and
# "class3" exist.
DATA_DIR = "../data/dataset/testPK/data"
  20. def test_imagefolder_basic():
  21. logger.info("Test Case basic")
  22. # define parameters
  23. repeat_count = 1
  24. # apply dataset operations
  25. data1 = ds.ImageFolderDataset(DATA_DIR)
  26. data1 = data1.repeat(repeat_count)
  27. num_iter = 0
  28. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  29. # in this example, each dictionary has keys "image" and "label"
  30. logger.info("image is {}".format(item["image"]))
  31. logger.info("label is {}".format(item["label"]))
  32. num_iter += 1
  33. logger.info("Number of data in data1: {}".format(num_iter))
  34. assert num_iter == 44
  35. def test_imagefolder_numsamples():
  36. logger.info("Test Case numSamples")
  37. # define parameters
  38. repeat_count = 1
  39. # apply dataset operations
  40. data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10, num_parallel_workers=2)
  41. data1 = data1.repeat(repeat_count)
  42. num_iter = 0
  43. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  44. # in this example, each dictionary has keys "image" and "label"
  45. logger.info("image is {}".format(item["image"]))
  46. logger.info("label is {}".format(item["label"]))
  47. num_iter += 1
  48. logger.info("Number of data in data1: {}".format(num_iter))
  49. assert num_iter == 10
  50. random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
  51. data1 = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
  52. num_iter = 0
  53. for item in data1.create_dict_iterator(num_epochs=1):
  54. num_iter += 1
  55. assert num_iter == 3
  56. random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
  57. data1 = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
  58. num_iter = 0
  59. for item in data1.create_dict_iterator(num_epochs=1):
  60. num_iter += 1
  61. assert num_iter == 3
  62. def test_imagefolder_numshards():
  63. logger.info("Test Case numShards")
  64. # define parameters
  65. repeat_count = 1
  66. # apply dataset operations
  67. data1 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)
  68. data1 = data1.repeat(repeat_count)
  69. num_iter = 0
  70. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  71. # in this example, each dictionary has keys "image" and "label"
  72. logger.info("image is {}".format(item["image"]))
  73. logger.info("label is {}".format(item["label"]))
  74. num_iter += 1
  75. logger.info("Number of data in data1: {}".format(num_iter))
  76. assert num_iter == 11
  77. def test_imagefolder_shardid():
  78. logger.info("Test Case withShardID")
  79. # define parameters
  80. repeat_count = 1
  81. # apply dataset operations
  82. data1 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=1)
  83. data1 = data1.repeat(repeat_count)
  84. num_iter = 0
  85. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  86. # in this example, each dictionary has keys "image" and "label"
  87. logger.info("image is {}".format(item["image"]))
  88. logger.info("label is {}".format(item["label"]))
  89. num_iter += 1
  90. logger.info("Number of data in data1: {}".format(num_iter))
  91. assert num_iter == 11
  92. def test_imagefolder_noshuffle():
  93. logger.info("Test Case noShuffle")
  94. # define parameters
  95. repeat_count = 1
  96. # apply dataset operations
  97. data1 = ds.ImageFolderDataset(DATA_DIR, shuffle=False)
  98. data1 = data1.repeat(repeat_count)
  99. num_iter = 0
  100. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  101. # in this example, each dictionary has keys "image" and "label"
  102. logger.info("image is {}".format(item["image"]))
  103. logger.info("label is {}".format(item["label"]))
  104. num_iter += 1
  105. logger.info("Number of data in data1: {}".format(num_iter))
  106. assert num_iter == 44
  107. def test_imagefolder_extrashuffle():
  108. logger.info("Test Case extraShuffle")
  109. # define parameters
  110. repeat_count = 2
  111. # apply dataset operations
  112. data1 = ds.ImageFolderDataset(DATA_DIR, shuffle=True)
  113. data1 = data1.shuffle(buffer_size=5)
  114. data1 = data1.repeat(repeat_count)
  115. num_iter = 0
  116. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  117. # in this example, each dictionary has keys "image" and "label"
  118. logger.info("image is {}".format(item["image"]))
  119. logger.info("label is {}".format(item["label"]))
  120. num_iter += 1
  121. logger.info("Number of data in data1: {}".format(num_iter))
  122. assert num_iter == 88
  123. def test_imagefolder_classindex():
  124. logger.info("Test Case classIndex")
  125. # define parameters
  126. repeat_count = 1
  127. # apply dataset operations
  128. class_index = {"class3": 333, "class1": 111}
  129. data1 = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False)
  130. data1 = data1.repeat(repeat_count)
  131. golden = [111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
  132. 333, 333, 333, 333, 333, 333, 333, 333, 333, 333, 333]
  133. num_iter = 0
  134. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  135. # in this example, each dictionary has keys "image" and "label"
  136. logger.info("image is {}".format(item["image"]))
  137. logger.info("label is {}".format(item["label"]))
  138. assert item["label"] == golden[num_iter]
  139. num_iter += 1
  140. logger.info("Number of data in data1: {}".format(num_iter))
  141. assert num_iter == 22
  142. def test_imagefolder_negative_classindex():
  143. logger.info("Test Case negative classIndex")
  144. # define parameters
  145. repeat_count = 1
  146. # apply dataset operations
  147. class_index = {"class3": -333, "class1": 111}
  148. data1 = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False)
  149. data1 = data1.repeat(repeat_count)
  150. golden = [111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
  151. -333, -333, -333, -333, -333, -333, -333, -333, -333, -333, -333]
  152. num_iter = 0
  153. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  154. # in this example, each dictionary has keys "image" and "label"
  155. logger.info("image is {}".format(item["image"]))
  156. logger.info("label is {}".format(item["label"]))
  157. assert item["label"] == golden[num_iter]
  158. num_iter += 1
  159. logger.info("Number of data in data1: {}".format(num_iter))
  160. assert num_iter == 22
  161. def test_imagefolder_extensions():
  162. logger.info("Test Case extensions")
  163. # define parameters
  164. repeat_count = 1
  165. # apply dataset operations
  166. ext = [".jpg", ".JPEG"]
  167. data1 = ds.ImageFolderDataset(DATA_DIR, extensions=ext)
  168. data1 = data1.repeat(repeat_count)
  169. num_iter = 0
  170. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  171. # in this example, each dictionary has keys "image" and "label"
  172. logger.info("image is {}".format(item["image"]))
  173. logger.info("label is {}".format(item["label"]))
  174. num_iter += 1
  175. logger.info("Number of data in data1: {}".format(num_iter))
  176. assert num_iter == 44
  177. def test_imagefolder_decode():
  178. logger.info("Test Case decode")
  179. # define parameters
  180. repeat_count = 1
  181. # apply dataset operations
  182. ext = [".jpg", ".JPEG"]
  183. data1 = ds.ImageFolderDataset(DATA_DIR, extensions=ext, decode=True)
  184. data1 = data1.repeat(repeat_count)
  185. num_iter = 0
  186. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  187. # in this example, each dictionary has keys "image" and "label"
  188. logger.info("image is {}".format(item["image"]))
  189. logger.info("label is {}".format(item["label"]))
  190. num_iter += 1
  191. logger.info("Number of data in data1: {}".format(num_iter))
  192. assert num_iter == 44
  193. def test_sequential_sampler():
  194. logger.info("Test Case SequentialSampler")
  195. golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  196. 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  197. 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
  198. 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
  199. # define parameters
  200. repeat_count = 1
  201. # apply dataset operations
  202. sampler = ds.SequentialSampler()
  203. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  204. data1 = data1.repeat(repeat_count)
  205. result = []
  206. num_iter = 0
  207. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  208. # in this example, each dictionary has keys "image" and "label"
  209. result.append(item["label"])
  210. num_iter += 1
  211. assert num_iter == 44
  212. logger.info("Result: {}".format(result))
  213. assert result == golden
  214. def test_random_sampler():
  215. logger.info("Test Case RandomSampler")
  216. # define parameters
  217. repeat_count = 1
  218. # apply dataset operations
  219. sampler = ds.RandomSampler()
  220. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  221. data1 = data1.repeat(repeat_count)
  222. num_iter = 0
  223. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  224. # in this example, each dictionary has keys "image" and "label"
  225. logger.info("image is {}".format(item["image"]))
  226. logger.info("label is {}".format(item["label"]))
  227. num_iter += 1
  228. logger.info("Number of data in data1: {}".format(num_iter))
  229. assert num_iter == 44
  230. def test_distributed_sampler():
  231. logger.info("Test Case DistributedSampler")
  232. # define parameters
  233. repeat_count = 1
  234. # apply dataset operations
  235. sampler = ds.DistributedSampler(10, 1)
  236. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  237. data1 = data1.repeat(repeat_count)
  238. num_iter = 0
  239. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  240. # in this example, each dictionary has keys "image" and "label"
  241. logger.info("image is {}".format(item["image"]))
  242. logger.info("label is {}".format(item["label"]))
  243. num_iter += 1
  244. logger.info("Number of data in data1: {}".format(num_iter))
  245. assert num_iter == 5
  246. def test_pk_sampler():
  247. logger.info("Test Case PKSampler")
  248. # define parameters
  249. repeat_count = 1
  250. # apply dataset operations
  251. sampler = ds.PKSampler(3)
  252. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  253. data1 = data1.repeat(repeat_count)
  254. num_iter = 0
  255. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  256. # in this example, each dictionary has keys "image" and "label"
  257. logger.info("image is {}".format(item["image"]))
  258. logger.info("label is {}".format(item["label"]))
  259. num_iter += 1
  260. logger.info("Number of data in data1: {}".format(num_iter))
  261. assert num_iter == 12
  262. def test_subset_random_sampler():
  263. logger.info("Test Case SubsetRandomSampler")
  264. # define parameters
  265. repeat_count = 1
  266. # apply dataset operations
  267. indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
  268. sampler = ds.SubsetRandomSampler(indices)
  269. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  270. data1 = data1.repeat(repeat_count)
  271. num_iter = 0
  272. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  273. # in this example, each dictionary has keys "image" and "label"
  274. logger.info("image is {}".format(item["image"]))
  275. logger.info("label is {}".format(item["label"]))
  276. num_iter += 1
  277. logger.info("Number of data in data1: {}".format(num_iter))
  278. assert num_iter == 12
  279. def test_weighted_random_sampler():
  280. logger.info("Test Case WeightedRandomSampler")
  281. # define parameters
  282. repeat_count = 1
  283. # apply dataset operations
  284. weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]
  285. sampler = ds.WeightedRandomSampler(weights, 11)
  286. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  287. data1 = data1.repeat(repeat_count)
  288. num_iter = 0
  289. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  290. # in this example, each dictionary has keys "image" and "label"
  291. logger.info("image is {}".format(item["image"]))
  292. logger.info("label is {}".format(item["label"]))
  293. num_iter += 1
  294. logger.info("Number of data in data1: {}".format(num_iter))
  295. assert num_iter == 11
  296. def test_weighted_random_sampler_exception():
  297. """
  298. Test error cases for WeightedRandomSampler
  299. """
  300. logger.info("Test error cases for WeightedRandomSampler")
  301. error_msg_1 = "type of weights element must be number"
  302. with pytest.raises(TypeError, match=error_msg_1):
  303. weights = ""
  304. ds.WeightedRandomSampler(weights)
  305. error_msg_2 = "type of weights element must be number"
  306. with pytest.raises(TypeError, match=error_msg_2):
  307. weights = (0.9, 0.8, 1.1)
  308. ds.WeightedRandomSampler(weights)
  309. error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
  310. with pytest.raises(RuntimeError, match=error_msg_3):
  311. weights = []
  312. sampler = ds.WeightedRandomSampler(weights)
  313. sampler.parse()
  314. error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
  315. with pytest.raises(RuntimeError, match=error_msg_4):
  316. weights = [1.0, 0.1, 0.02, 0.3, -0.4]
  317. sampler = ds.WeightedRandomSampler(weights)
  318. sampler.parse()
  319. error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
  320. with pytest.raises(RuntimeError, match=error_msg_5):
  321. weights = [0, 0, 0, 0, 0]
  322. sampler = ds.WeightedRandomSampler(weights)
  323. sampler.parse()
  324. def test_chained_sampler_01():
  325. logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
  326. # Create chained sampler, random and sequential
  327. sampler = ds.RandomSampler()
  328. child_sampler = ds.SequentialSampler()
  329. sampler.add_child(child_sampler)
  330. # Create ImageFolderDataset with sampler
  331. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  332. data1 = data1.repeat(count=3)
  333. # Verify dataset size
  334. data1_size = data1.get_dataset_size()
  335. logger.info("dataset size is: {}".format(data1_size))
  336. assert data1_size == 132
  337. # Verify number of iterations
  338. num_iter = 0
  339. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  340. # in this example, each dictionary has keys "image" and "label"
  341. logger.info("image is {}".format(item["image"]))
  342. logger.info("label is {}".format(item["label"]))
  343. num_iter += 1
  344. logger.info("Number of data in data1: {}".format(num_iter))
  345. assert num_iter == 132
  346. def test_chained_sampler_02():
  347. logger.info("Test Case Chained Sampler - Random and Sequential, with batch then repeat")
  348. # Create chained sampler, random and sequential
  349. sampler = ds.RandomSampler()
  350. child_sampler = ds.SequentialSampler()
  351. sampler.add_child(child_sampler)
  352. # Create ImageFolderDataset with sampler
  353. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  354. data1 = data1.batch(batch_size=5, drop_remainder=True)
  355. data1 = data1.repeat(count=2)
  356. # Verify dataset size
  357. data1_size = data1.get_dataset_size()
  358. logger.info("dataset size is: {}".format(data1_size))
  359. assert data1_size == 16
  360. # Verify number of iterations
  361. num_iter = 0
  362. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  363. # in this example, each dictionary has keys "image" and "label"
  364. logger.info("image is {}".format(item["image"]))
  365. logger.info("label is {}".format(item["label"]))
  366. num_iter += 1
  367. logger.info("Number of data in data1: {}".format(num_iter))
  368. assert num_iter == 16
  369. def test_chained_sampler_03():
  370. logger.info("Test Case Chained Sampler - Random and Sequential, with repeat then batch")
  371. # Create chained sampler, random and sequential
  372. sampler = ds.RandomSampler()
  373. child_sampler = ds.SequentialSampler()
  374. sampler.add_child(child_sampler)
  375. # Create ImageFolderDataset with sampler
  376. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  377. data1 = data1.repeat(count=2)
  378. data1 = data1.batch(batch_size=5, drop_remainder=False)
  379. # Verify dataset size
  380. data1_size = data1.get_dataset_size()
  381. logger.info("dataset size is: {}".format(data1_size))
  382. assert data1_size == 18
  383. # Verify number of iterations
  384. num_iter = 0
  385. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  386. # in this example, each dictionary has keys "image" and "label"
  387. logger.info("image is {}".format(item["image"]))
  388. logger.info("label is {}".format(item["label"]))
  389. num_iter += 1
  390. logger.info("Number of data in data1: {}".format(num_iter))
  391. assert num_iter == 18
  392. def test_chained_sampler_04():
  393. logger.info("Test Case Chained Sampler - Distributed and Random, with batch then repeat")
  394. # Create chained sampler, distributed and random
  395. sampler = ds.DistributedSampler(num_shards=4, shard_id=3)
  396. child_sampler = ds.RandomSampler()
  397. sampler.add_child(child_sampler)
  398. # Create ImageFolderDataset with sampler
  399. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  400. data1 = data1.batch(batch_size=5, drop_remainder=True)
  401. data1 = data1.repeat(count=3)
  402. # Verify dataset size
  403. data1_size = data1.get_dataset_size()
  404. logger.info("dataset size is: {}".format(data1_size))
  405. assert data1_size == 6
  406. # Verify number of iterations
  407. num_iter = 0
  408. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  409. # in this example, each dictionary has keys "image" and "label"
  410. logger.info("image is {}".format(item["image"]))
  411. logger.info("label is {}".format(item["label"]))
  412. num_iter += 1
  413. logger.info("Number of data in data1: {}".format(num_iter))
  414. # Note: Each of the 4 shards has 44/4=11 samples
  415. # Note: Number of iterations is (11/5 = 2) * 3 = 6
  416. assert num_iter == 6
  417. def skip_test_chained_sampler_05():
  418. logger.info("Test Case Chained Sampler - PKSampler and WeightedRandom")
  419. # Create chained sampler, PKSampler and WeightedRandom
  420. sampler = ds.PKSampler(num_val=3) # Number of elements per class is 3 (and there are 4 classes)
  421. weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 0.5]
  422. child_sampler = ds.WeightedRandomSampler(weights, num_samples=12)
  423. sampler.add_child(child_sampler)
  424. # Create ImageFolderDataset with sampler
  425. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  426. # Verify dataset size
  427. data1_size = data1.get_dataset_size()
  428. logger.info("dataset size is: {}".format(data1_size))
  429. assert data1_size == 12
  430. # Verify number of iterations
  431. num_iter = 0
  432. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  433. # in this example, each dictionary has keys "image" and "label"
  434. logger.info("image is {}".format(item["image"]))
  435. logger.info("label is {}".format(item["label"]))
  436. num_iter += 1
  437. logger.info("Number of data in data1: {}".format(num_iter))
  438. # Note: PKSampler produces 4x3=12 samples
  439. # Note: Child WeightedRandomSampler produces 12 samples
  440. assert num_iter == 12
  441. def test_chained_sampler_06():
  442. logger.info("Test Case Chained Sampler - WeightedRandom and PKSampler")
  443. # Create chained sampler, WeightedRandom and PKSampler
  444. weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 0.5]
  445. sampler = ds.WeightedRandomSampler(weights=weights, num_samples=12)
  446. child_sampler = ds.PKSampler(num_val=3) # Number of elements per class is 3 (and there are 4 classes)
  447. sampler.add_child(child_sampler)
  448. # Create ImageFolderDataset with sampler
  449. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  450. # Verify dataset size
  451. data1_size = data1.get_dataset_size()
  452. logger.info("dataset size is: {}".format(data1_size))
  453. assert data1_size == 12
  454. # Verify number of iterations
  455. num_iter = 0
  456. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  457. # in this example, each dictionary has keys "image" and "label"
  458. logger.info("image is {}".format(item["image"]))
  459. logger.info("label is {}".format(item["label"]))
  460. num_iter += 1
  461. logger.info("Number of data in data1: {}".format(num_iter))
  462. # Note: WeightedRandomSampler produces 12 samples
  463. # Note: Child PKSampler produces 12 samples
  464. assert num_iter == 12
  465. def test_chained_sampler_07():
  466. logger.info("Test Case Chained Sampler - SubsetRandom and Distributed, 2 shards")
  467. # Create chained sampler, subset random and distributed
  468. indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
  469. sampler = ds.SubsetRandomSampler(indices, num_samples=12)
  470. child_sampler = ds.DistributedSampler(num_shards=2, shard_id=1)
  471. sampler.add_child(child_sampler)
  472. # Create ImageFolderDataset with sampler
  473. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  474. # Verify dataset size
  475. data1_size = data1.get_dataset_size()
  476. logger.info("dataset size is: {}".format(data1_size))
  477. assert data1_size == 12
  478. # Verify number of iterations
  479. num_iter = 0
  480. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  481. # in this example, each dictionary has keys "image" and "label"
  482. logger.info("image is {}".format(item["image"]))
  483. logger.info("label is {}".format(item["label"]))
  484. num_iter += 1
  485. logger.info("Number of data in data1: {}".format(num_iter))
  486. # Note: SubsetRandomSampler produces 12 samples
  487. # Note: Each of 2 shards has 6 samples
  488. # FIXME: Uncomment the following assert when code issue is resolved; at runtime, number of samples is 12 not 6
  489. # assert num_iter == 6
  490. def skip_test_chained_sampler_08():
  491. logger.info("Test Case Chained Sampler - SubsetRandom and Distributed, 4 shards")
  492. # Create chained sampler, subset random and distributed
  493. indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
  494. sampler = ds.SubsetRandomSampler(indices, num_samples=12)
  495. child_sampler = ds.DistributedSampler(num_shards=4, shard_id=1)
  496. sampler.add_child(child_sampler)
  497. # Create ImageFolderDataset with sampler
  498. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  499. # Verify dataset size
  500. data1_size = data1.get_dataset_size()
  501. logger.info("dataset size is: {}".format(data1_size))
  502. assert data1_size == 3
  503. # Verify number of iterations
  504. num_iter = 0
  505. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  506. # in this example, each dictionary has keys "image" and "label"
  507. logger.info("image is {}".format(item["image"]))
  508. logger.info("label is {}".format(item["label"]))
  509. num_iter += 1
  510. logger.info("Number of data in data1: {}".format(num_iter))
  511. # Note: SubsetRandomSampler returns 12 samples
  512. # Note: Each of 4 shards has 3 samples
  513. assert num_iter == 3
  514. def test_imagefolder_rename():
  515. logger.info("Test Case rename")
  516. # define parameters
  517. repeat_count = 1
  518. # apply dataset operations
  519. data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
  520. data1 = data1.repeat(repeat_count)
  521. num_iter = 0
  522. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  523. # in this example, each dictionary has keys "image" and "label"
  524. logger.info("image is {}".format(item["image"]))
  525. logger.info("label is {}".format(item["label"]))
  526. num_iter += 1
  527. logger.info("Number of data in data1: {}".format(num_iter))
  528. assert num_iter == 10
  529. data1 = data1.rename(input_columns=["image"], output_columns="image2")
  530. num_iter = 0
  531. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  532. # in this example, each dictionary has keys "image" and "label"
  533. logger.info("image is {}".format(item["image2"]))
  534. logger.info("label is {}".format(item["label"]))
  535. num_iter += 1
  536. logger.info("Number of data in data1: {}".format(num_iter))
  537. assert num_iter == 10
  538. def test_imagefolder_zip():
  539. logger.info("Test Case zip")
  540. # define parameters
  541. repeat_count = 2
  542. # apply dataset operations
  543. data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
  544. data2 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
  545. data1 = data1.repeat(repeat_count)
  546. # rename dataset2 for no conflict
  547. data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
  548. data3 = ds.zip((data1, data2))
  549. num_iter = 0
  550. for item in data3.create_dict_iterator(num_epochs=1): # each data is a dictionary
  551. # in this example, each dictionary has keys "image" and "label"
  552. logger.info("image is {}".format(item["image"]))
  553. logger.info("label is {}".format(item["label"]))
  554. num_iter += 1
  555. logger.info("Number of data in data1: {}".format(num_iter))
  556. assert num_iter == 10
  557. def test_imagefolder_exception():
  558. logger.info("Test imagefolder exception")
  559. def exception_func(item):
  560. raise Exception("Error occur!")
  561. def exception_func2(image, label):
  562. raise Exception("Error occur!")
  563. try:
  564. data = ds.ImageFolderDataset(DATA_DIR)
  565. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  566. for _ in data.__iter__():
  567. pass
  568. assert False
  569. except RuntimeError as e:
  570. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  571. try:
  572. data = ds.ImageFolderDataset(DATA_DIR)
  573. data = data.map(operations=exception_func2, input_columns=["image", "label"],
  574. output_columns=["image", "label", "label1"],
  575. column_order=["image", "label", "label1"], num_parallel_workers=1)
  576. for _ in data.__iter__():
  577. pass
  578. assert False
  579. except RuntimeError as e:
  580. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  581. try:
  582. data = ds.ImageFolderDataset(DATA_DIR)
  583. data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
  584. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  585. for _ in data.__iter__():
  586. pass
  587. assert False
  588. except RuntimeError as e:
  589. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  590. if __name__ == '__main__':
  591. test_imagefolder_basic()
  592. logger.info('test_imagefolder_basic Ended.\n')
  593. test_imagefolder_numsamples()
  594. logger.info('test_imagefolder_numsamples Ended.\n')
  595. test_sequential_sampler()
  596. logger.info('test_sequential_sampler Ended.\n')
  597. test_random_sampler()
  598. logger.info('test_random_sampler Ended.\n')
  599. test_distributed_sampler()
  600. logger.info('test_distributed_sampler Ended.\n')
  601. test_pk_sampler()
  602. logger.info('test_pk_sampler Ended.\n')
  603. test_subset_random_sampler()
  604. logger.info('test_subset_random_sampler Ended.\n')
  605. test_weighted_random_sampler()
  606. logger.info('test_weighted_random_sampler Ended.\n')
  607. test_weighted_random_sampler_exception()
  608. logger.info('test_weighted_random_sampler_exception Ended.\n')
  609. test_chained_sampler_01()
  610. logger.info('test_chained_sampler_01 Ended.\n')
  611. test_chained_sampler_02()
  612. logger.info('test_chained_sampler_02 Ended.\n')
  613. test_chained_sampler_03()
  614. logger.info('test_chained_sampler_03 Ended.\n')
  615. test_chained_sampler_04()
  616. logger.info('test_chained_sampler_04 Ended.\n')
  617. # test_chained_sampler_05()
  618. # logger.info('test_chained_sampler_05 Ended.\n')
  619. test_chained_sampler_06()
  620. logger.info('test_chained_sampler_06 Ended.\n')
  621. test_chained_sampler_07()
  622. logger.info('test_chained_sampler_07 Ended.\n')
  623. # test_chained_sampler_08()
  624. # logger.info('test_chained_sampler_07 Ended.\n')
  625. test_imagefolder_numshards()
  626. logger.info('test_imagefolder_numshards Ended.\n')
  627. test_imagefolder_shardid()
  628. logger.info('test_imagefolder_shardid Ended.\n')
  629. test_imagefolder_noshuffle()
  630. logger.info('test_imagefolder_noshuffle Ended.\n')
  631. test_imagefolder_extrashuffle()
  632. logger.info('test_imagefolder_extrashuffle Ended.\n')
  633. test_imagefolder_classindex()
  634. logger.info('test_imagefolder_classindex Ended.\n')
  635. test_imagefolder_negative_classindex()
  636. logger.info('test_imagefolder_negative_classindex Ended.\n')
  637. test_imagefolder_extensions()
  638. logger.info('test_imagefolder_extensions Ended.\n')
  639. test_imagefolder_decode()
  640. logger.info('test_imagefolder_decode Ended.\n')
  641. test_imagefolder_rename()
  642. logger.info('test_imagefolder_rename Ended.\n')
  643. test_imagefolder_zip()
  644. logger.info('test_imagefolder_zip Ended.\n')
  645. test_imagefolder_exception()
  646. logger.info('test_imagefolder_exception Ended.\n')