You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_imagefolder.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import mindspore.dataset as ds
  16. from mindspore import log as logger
  # Root directory of the ImageFolder test images used by every test below.
  # NOTE(review): relative path, presumably resolved against the directory the
  # tests are launched from -- confirm. The expected counts in this file assume
  # 44 images split into 4 label groups of 11 (see the golden list in
  # test_sequential_sampler).
  DATA_DIR = "../data/dataset/testPK/data"
  def test_imagefolder_basic():
      """Read the whole ImageFolder dataset with default options.

      Logs each row's image and label and expects 44 rows in a single pass.
      """
      logger.info("Test Case basic")
      # one pass over the data (repeat count of 1)
      dataset = ds.ImageFolderDatasetV2(DATA_DIR).repeat(1)
      rows_seen = 0
      for row in dataset.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          rows_seen += 1
      logger.info("Number of data in data1: {}".format(rows_seen))
      assert rows_seen == 44
  def test_imagefolder_numsamples():
      """Limit the dataset to 10 samples (read with 2 parallel workers).

      Expects exactly 10 rows from a single pass.
      """
      logger.info("Test Case numSamples")
      source = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2)
      source = source.repeat(1)
      count = 0
      for record in source.create_dict_iterator():
          # every record is a dict with "image" and "label" entries
          logger.info("image is {}".format(record["image"]))
          logger.info("label is {}".format(record["label"]))
          count += 1
      logger.info("Number of data in data1: {}".format(count))
      assert count == 10
  def test_imagefolder_numshards():
      """Read shard 3 of 4 and expect 11 rows (a quarter of the 44 images)."""
      logger.info("Test Case numShards")
      shard = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=3).repeat(1)
      total = 0
      for entry in shard.create_dict_iterator():
          # every entry is a dict with "image" and "label" entries
          logger.info("image is {}".format(entry["image"]))
          logger.info("label is {}".format(entry["label"]))
          total += 1
      logger.info("Number of data in data1: {}".format(total))
      assert total == 11
  def test_imagefolder_shardid():
      """Read shard 1 of 4 and expect 11 rows."""
      logger.info("Test Case withShardID")
      shard = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=1)
      shard = shard.repeat(1)
      total = 0
      for entry in shard.create_dict_iterator():
          # every entry is a dict with "image" and "label" entries
          logger.info("image is {}".format(entry["image"]))
          logger.info("label is {}".format(entry["label"]))
          total += 1
      logger.info("Number of data in data1: {}".format(total))
      assert total == 11
  def test_imagefolder_noshuffle():
      """Read the full dataset with shuffling disabled; expect all 44 rows."""
      logger.info("Test Case noShuffle")
      ordered = ds.ImageFolderDatasetV2(DATA_DIR, shuffle=False).repeat(1)
      seen = 0
      for sample in ordered.create_dict_iterator():
          # every sample is a dict with "image" and "label" entries
          logger.info("image is {}".format(sample["image"]))
          logger.info("label is {}".format(sample["label"]))
          seen += 1
      logger.info("Number of data in data1: {}".format(seen))
      assert seen == 44
  def test_imagefolder_extrashuffle():
      """Shuffle at the source and again with a buffer of 5, then repeat twice.

      Two passes over 44 images are expected to yield 88 rows.
      """
      logger.info("Test Case extraShuffle")
      passes = 2
      pipeline = ds.ImageFolderDatasetV2(DATA_DIR, shuffle=True)
      # the extra shuffle is applied before the repeat, as in the original order
      pipeline = pipeline.shuffle(buffer_size=5)
      pipeline = pipeline.repeat(passes)
      seen = 0
      for sample in pipeline.create_dict_iterator():
          # every sample is a dict with "image" and "label" entries
          logger.info("image is {}".format(sample["image"]))
          logger.info("label is {}".format(sample["label"]))
          seen += 1
      logger.info("Number of data in data1: {}".format(seen))
      assert seen == 88
  def test_imagefolder_classindex():
      """Map two class folders to custom label values and check label order.

      With class_indexing {"class3": 333, "class1": 111} and no shuffling, the
      test expects 11 rows labelled 111 followed by 11 rows labelled 333.
      """
      logger.info("Test Case classIndex")
      mapping = {"class3": 333, "class1": 111}
      dataset = ds.ImageFolderDatasetV2(DATA_DIR, class_indexing=mapping, shuffle=False)
      dataset = dataset.repeat(1)
      expected_labels = [111] * 11 + [333] * 11
      position = 0
      for row in dataset.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          assert row["label"] == expected_labels[position]
          position += 1
      logger.info("Number of data in data1: {}".format(position))
      assert position == 22
  def test_imagefolder_negative_classindex():
      """Same as the classIndex test, but one class is mapped to a negative label.

      Expects 11 rows labelled 111 followed by 11 rows labelled -333.
      """
      logger.info("Test Case negative classIndex")
      mapping = {"class3": -333, "class1": 111}
      dataset = ds.ImageFolderDatasetV2(DATA_DIR, class_indexing=mapping, shuffle=False)
      dataset = dataset.repeat(1)
      expected_labels = [111] * 11 + [-333] * 11
      position = 0
      for row in dataset.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          assert row["label"] == expected_labels[position]
          position += 1
      logger.info("Number of data in data1: {}".format(position))
      assert position == 22
  def test_imagefolder_extensions():
      """Restrict files to .jpg/.JPEG extensions; all 44 images are expected."""
      logger.info("Test Case extensions")
      allowed = [".jpg", ".JPEG"]
      images = ds.ImageFolderDatasetV2(DATA_DIR, extensions=allowed).repeat(1)
      n_rows = 0
      for row in images.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          n_rows += 1
      logger.info("Number of data in data1: {}".format(n_rows))
      assert n_rows == 44
  def test_imagefolder_decode():
      """Read .jpg/.JPEG files with decode=True; all 44 images are expected."""
      logger.info("Test Case decode")
      allowed = [".jpg", ".JPEG"]
      images = ds.ImageFolderDatasetV2(DATA_DIR, extensions=allowed, decode=True)
      images = images.repeat(1)
      n_rows = 0
      for row in images.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          n_rows += 1
      logger.info("Number of data in data1: {}".format(n_rows))
      assert n_rows == 44
  def test_sequential_sampler():
      """Read the dataset with a SequentialSampler and check the label order.

      The expected labels are 0, 1, 2 and 3, each repeated 11 times, in
      ascending order (44 rows in total).
      """
      logger.info("Test Case SequentialSampler")
      # 4 label groups x 11 images each, in ascending label order; built rather
      # than spelled out so the 44-element list cannot drift out of shape.
      golden = [label for label in range(4) for _ in range(11)]
      # define parameters
      repeat_count = 1
      # apply dataset operations
      sampler = ds.SequentialSampler()
      data1 = ds.ImageFolderDatasetV2(DATA_DIR, sampler=sampler)
      data1 = data1.repeat(repeat_count)
      # each row is a dictionary with "image" and "label"; only labels are kept
      result = [item["label"] for item in data1.create_dict_iterator()]
      logger.info("Result: {}".format(result))
      assert result == golden
  def test_random_sampler():
      """Read the dataset through a default RandomSampler; expect all 44 rows."""
      logger.info("Test Case RandomSampler")
      picker = ds.RandomSampler()
      dataset = ds.ImageFolderDatasetV2(DATA_DIR, sampler=picker).repeat(1)
      produced = 0
      for row in dataset.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          produced += 1
      logger.info("Number of data in data1: {}".format(produced))
      assert produced == 44
  def test_distributed_sampler():
      """Read the dataset through DistributedSampler(10, 1); expect 5 rows.

      NOTE(review): the positional arguments are presumably (num_shards,
      shard_id); 5 rows matches ceil(44 / 10) -- confirm against the API.
      """
      logger.info("Test Case DistributedSampler")
      picker = ds.DistributedSampler(10, 1)
      dataset = ds.ImageFolderDatasetV2(DATA_DIR, sampler=picker)
      dataset = dataset.repeat(1)
      produced = 0
      for row in dataset.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          produced += 1
      logger.info("Number of data in data1: {}".format(produced))
      assert produced == 5
  def test_pk_sampler():
      """Read the dataset through PKSampler(3); expect 12 rows.

      NOTE(review): 3 is presumably the number of samples drawn per class,
      giving 4 classes x 3 = 12 -- confirm against the API.
      """
      logger.info("Test Case PKSampler")
      picker = ds.PKSampler(3)
      dataset = ds.ImageFolderDatasetV2(DATA_DIR, sampler=picker).repeat(1)
      produced = 0
      for row in dataset.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          produced += 1
      logger.info("Number of data in data1: {}".format(produced))
      assert produced == 12
  def test_subset_random_sampler():
      """Sample a fixed set of 12 indices in random order; expect 12 rows."""
      logger.info("Test Case SubsetRandomSampler")
      chosen = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
      picker = ds.SubsetRandomSampler(chosen)
      dataset = ds.ImageFolderDatasetV2(DATA_DIR, sampler=picker)
      dataset = dataset.repeat(1)
      produced = 0
      for row in dataset.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          produced += 1
      logger.info("Number of data in data1: {}".format(produced))
      assert produced == 12
  def test_weighted_random_sampler():
      """Draw 11 rows with a WeightedRandomSampler over 12 weights.

      NOTE(review): the second argument (11) is presumably the number of
      samples to draw -- confirm against the API.
      """
      logger.info("Test Case WeightedRandomSampler")
      sample_weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]
      picker = ds.WeightedRandomSampler(sample_weights, 11)
      dataset = ds.ImageFolderDatasetV2(DATA_DIR, sampler=picker).repeat(1)
      produced = 0
      for row in dataset.create_dict_iterator():
          # every row is a dict with "image" and "label" entries
          logger.info("image is {}".format(row["image"]))
          logger.info("label is {}".format(row["label"]))
          produced += 1
      logger.info("Number of data in data1: {}".format(produced))
      assert produced == 11
  def test_imagefolder_rename():
      """Rename the "image" column to "image2" and check rows survive the rename.

      The first pass reads the original "image"/"label" columns from a
      10-sample dataset. The second pass reads the same pipeline after the
      rename. It checks that "image2" is present and the old "image" key is gone.
      """
      logger.info("Test Case rename")
      # define parameters
      repeat_count = 1
      # apply dataset operations
      data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10)
      data1 = data1.repeat(repeat_count)
      num_iter = 0
      for item in data1.create_dict_iterator():  # each data is a dictionary
          # before the rename, each dictionary has keys "image" and "label"
          logger.info("image is {}".format(item["image"]))
          logger.info("label is {}".format(item["label"]))
          num_iter += 1
      logger.info("Number of data in data1: {}".format(num_iter))
      assert num_iter == 10

      # Pass lists for both column arguments, matching the zip test's rename call.
      data1 = data1.rename(input_columns=["image"], output_columns=["image2"])
      num_iter = 0
      for item in data1.create_dict_iterator():  # each data is a dictionary
          # after the rename the keys are "image2" and "label"; the old name must be gone
          assert "image" not in item
          logger.info("image is {}".format(item["image2"]))
          logger.info("label is {}".format(item["label"]))
          num_iter += 1
      logger.info("Number of data in data1: {}".format(num_iter))
      assert num_iter == 10
  def test_imagefolder_zip():
      """Zip two ImageFolder pipelines and check the combined rows.

      data1 yields 10 samples repeated twice (20 rows). data2 yields 10 rows
      whose columns are renamed to "image1"/"label1" so the zipped column names
      do not collide. The test expects 10 zipped rows, the length of the shorter
      input, and requires each row to carry the columns of both inputs.
      """
      logger.info("Test Case zip")
      # define parameters
      repeat_count = 2
      # apply dataset operations
      data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10)
      data2 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10)
      data1 = data1.repeat(repeat_count)
      # rename dataset2 so its columns do not conflict with dataset1's
      data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
      data3 = ds.zip((data1, data2))
      num_iter = 0
      for item in data3.create_dict_iterator():  # each data is a dictionary
          # a zipped row has "image"/"label" from data1 and "image1"/"label1"
          # from data2; indexing the renamed keys fails fast if they are missing
          logger.info("image is {}".format(item["image"]))
          logger.info("label is {}".format(item["label"]))
          logger.info("image1 is {}".format(item["image1"]))
          logger.info("label1 is {}".format(item["label1"]))
          num_iter += 1
      logger.info("Number of data in data3: {}".format(num_iter))
      assert num_iter == 10
  if __name__ == '__main__':
      # Run every test in a fixed order, logging "<test name> Ended." after each.
      for test_fn in (
              test_imagefolder_basic,
              test_imagefolder_numsamples,
              test_sequential_sampler,
              test_random_sampler,
              test_distributed_sampler,
              test_pk_sampler,
              test_subset_random_sampler,
              test_weighted_random_sampler,
              test_imagefolder_numshards,
              test_imagefolder_shardid,
              test_imagefolder_noshuffle,
              test_imagefolder_extrashuffle,
              test_imagefolder_classindex,
              test_imagefolder_negative_classindex,
              test_imagefolder_extensions,
              test_imagefolder_decode,
              test_imagefolder_rename,
              test_imagefolder_zip,
      ):
          test_fn()
          logger.info('{} Ended.\n'.format(test_fn.__name__))