You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_imagefolder.py 18 kB


  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
# Root of the image-folder test data read by the dataset tests in this file.
# NOTE(review): the path is relative, so it presumably resolves against the
# directory the tests are launched from -- TODO confirm the expected launch dir.
DATA_DIR = "../data/dataset/testPK/data"
  19. def test_imagefolder_basic():
  20. logger.info("Test Case basic")
  21. # define parameters
  22. repeat_count = 1
  23. # apply dataset operations
  24. data1 = ds.ImageFolderDataset(DATA_DIR)
  25. data1 = data1.repeat(repeat_count)
  26. num_iter = 0
  27. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  28. # in this example, each dictionary has keys "image" and "label"
  29. logger.info("image is {}".format(item["image"]))
  30. logger.info("label is {}".format(item["label"]))
  31. num_iter += 1
  32. logger.info("Number of data in data1: {}".format(num_iter))
  33. assert num_iter == 44
  34. def test_imagefolder_numsamples():
  35. logger.info("Test Case numSamples")
  36. # define parameters
  37. repeat_count = 1
  38. # apply dataset operations
  39. data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10, num_parallel_workers=2)
  40. data1 = data1.repeat(repeat_count)
  41. num_iter = 0
  42. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  43. # in this example, each dictionary has keys "image" and "label"
  44. logger.info("image is {}".format(item["image"]))
  45. logger.info("label is {}".format(item["label"]))
  46. num_iter += 1
  47. logger.info("Number of data in data1: {}".format(num_iter))
  48. assert num_iter == 10
  49. random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
  50. data1 = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
  51. num_iter = 0
  52. for item in data1.create_dict_iterator(num_epochs=1):
  53. num_iter += 1
  54. assert num_iter == 3
  55. random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
  56. data1 = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
  57. num_iter = 0
  58. for item in data1.create_dict_iterator(num_epochs=1):
  59. num_iter += 1
  60. assert num_iter == 3
  61. def test_imagefolder_numshards():
  62. logger.info("Test Case numShards")
  63. # define parameters
  64. repeat_count = 1
  65. # apply dataset operations
  66. data1 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)
  67. data1 = data1.repeat(repeat_count)
  68. num_iter = 0
  69. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  70. # in this example, each dictionary has keys "image" and "label"
  71. logger.info("image is {}".format(item["image"]))
  72. logger.info("label is {}".format(item["label"]))
  73. num_iter += 1
  74. logger.info("Number of data in data1: {}".format(num_iter))
  75. assert num_iter == 11
  76. def test_imagefolder_shardid():
  77. logger.info("Test Case withShardID")
  78. # define parameters
  79. repeat_count = 1
  80. # apply dataset operations
  81. data1 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=1)
  82. data1 = data1.repeat(repeat_count)
  83. num_iter = 0
  84. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  85. # in this example, each dictionary has keys "image" and "label"
  86. logger.info("image is {}".format(item["image"]))
  87. logger.info("label is {}".format(item["label"]))
  88. num_iter += 1
  89. logger.info("Number of data in data1: {}".format(num_iter))
  90. assert num_iter == 11
  91. def test_imagefolder_noshuffle():
  92. logger.info("Test Case noShuffle")
  93. # define parameters
  94. repeat_count = 1
  95. # apply dataset operations
  96. data1 = ds.ImageFolderDataset(DATA_DIR, shuffle=False)
  97. data1 = data1.repeat(repeat_count)
  98. num_iter = 0
  99. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  100. # in this example, each dictionary has keys "image" and "label"
  101. logger.info("image is {}".format(item["image"]))
  102. logger.info("label is {}".format(item["label"]))
  103. num_iter += 1
  104. logger.info("Number of data in data1: {}".format(num_iter))
  105. assert num_iter == 44
  106. def test_imagefolder_extrashuffle():
  107. logger.info("Test Case extraShuffle")
  108. # define parameters
  109. repeat_count = 2
  110. # apply dataset operations
  111. data1 = ds.ImageFolderDataset(DATA_DIR, shuffle=True)
  112. data1 = data1.shuffle(buffer_size=5)
  113. data1 = data1.repeat(repeat_count)
  114. num_iter = 0
  115. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  116. # in this example, each dictionary has keys "image" and "label"
  117. logger.info("image is {}".format(item["image"]))
  118. logger.info("label is {}".format(item["label"]))
  119. num_iter += 1
  120. logger.info("Number of data in data1: {}".format(num_iter))
  121. assert num_iter == 88
  122. def test_imagefolder_classindex():
  123. logger.info("Test Case classIndex")
  124. # define parameters
  125. repeat_count = 1
  126. # apply dataset operations
  127. class_index = {"class3": 333, "class1": 111}
  128. data1 = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False)
  129. data1 = data1.repeat(repeat_count)
  130. golden = [111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
  131. 333, 333, 333, 333, 333, 333, 333, 333, 333, 333, 333]
  132. num_iter = 0
  133. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  134. # in this example, each dictionary has keys "image" and "label"
  135. logger.info("image is {}".format(item["image"]))
  136. logger.info("label is {}".format(item["label"]))
  137. assert item["label"] == golden[num_iter]
  138. num_iter += 1
  139. logger.info("Number of data in data1: {}".format(num_iter))
  140. assert num_iter == 22
  141. def test_imagefolder_negative_classindex():
  142. logger.info("Test Case negative classIndex")
  143. # define parameters
  144. repeat_count = 1
  145. # apply dataset operations
  146. class_index = {"class3": -333, "class1": 111}
  147. data1 = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False)
  148. data1 = data1.repeat(repeat_count)
  149. golden = [111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
  150. -333, -333, -333, -333, -333, -333, -333, -333, -333, -333, -333]
  151. num_iter = 0
  152. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  153. # in this example, each dictionary has keys "image" and "label"
  154. logger.info("image is {}".format(item["image"]))
  155. logger.info("label is {}".format(item["label"]))
  156. assert item["label"] == golden[num_iter]
  157. num_iter += 1
  158. logger.info("Number of data in data1: {}".format(num_iter))
  159. assert num_iter == 22
  160. def test_imagefolder_extensions():
  161. logger.info("Test Case extensions")
  162. # define parameters
  163. repeat_count = 1
  164. # apply dataset operations
  165. ext = [".jpg", ".JPEG"]
  166. data1 = ds.ImageFolderDataset(DATA_DIR, extensions=ext)
  167. data1 = data1.repeat(repeat_count)
  168. num_iter = 0
  169. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  170. # in this example, each dictionary has keys "image" and "label"
  171. logger.info("image is {}".format(item["image"]))
  172. logger.info("label is {}".format(item["label"]))
  173. num_iter += 1
  174. logger.info("Number of data in data1: {}".format(num_iter))
  175. assert num_iter == 44
  176. def test_imagefolder_decode():
  177. logger.info("Test Case decode")
  178. # define parameters
  179. repeat_count = 1
  180. # apply dataset operations
  181. ext = [".jpg", ".JPEG"]
  182. data1 = ds.ImageFolderDataset(DATA_DIR, extensions=ext, decode=True)
  183. data1 = data1.repeat(repeat_count)
  184. num_iter = 0
  185. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  186. # in this example, each dictionary has keys "image" and "label"
  187. logger.info("image is {}".format(item["image"]))
  188. logger.info("label is {}".format(item["label"]))
  189. num_iter += 1
  190. logger.info("Number of data in data1: {}".format(num_iter))
  191. assert num_iter == 44
  192. def test_sequential_sampler():
  193. logger.info("Test Case SequentialSampler")
  194. golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  195. 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  196. 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
  197. 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
  198. # define parameters
  199. repeat_count = 1
  200. # apply dataset operations
  201. sampler = ds.SequentialSampler()
  202. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  203. data1 = data1.repeat(repeat_count)
  204. result = []
  205. num_iter = 0
  206. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  207. # in this example, each dictionary has keys "image" and "label"
  208. result.append(item["label"])
  209. num_iter += 1
  210. logger.info("Result: {}".format(result))
  211. assert result == golden
  212. def test_random_sampler():
  213. logger.info("Test Case RandomSampler")
  214. # define parameters
  215. repeat_count = 1
  216. # apply dataset operations
  217. sampler = ds.RandomSampler()
  218. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  219. data1 = data1.repeat(repeat_count)
  220. num_iter = 0
  221. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  222. # in this example, each dictionary has keys "image" and "label"
  223. logger.info("image is {}".format(item["image"]))
  224. logger.info("label is {}".format(item["label"]))
  225. num_iter += 1
  226. logger.info("Number of data in data1: {}".format(num_iter))
  227. assert num_iter == 44
  228. def test_distributed_sampler():
  229. logger.info("Test Case DistributedSampler")
  230. # define parameters
  231. repeat_count = 1
  232. # apply dataset operations
  233. sampler = ds.DistributedSampler(10, 1)
  234. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  235. data1 = data1.repeat(repeat_count)
  236. num_iter = 0
  237. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  238. # in this example, each dictionary has keys "image" and "label"
  239. logger.info("image is {}".format(item["image"]))
  240. logger.info("label is {}".format(item["label"]))
  241. num_iter += 1
  242. logger.info("Number of data in data1: {}".format(num_iter))
  243. assert num_iter == 5
  244. def test_pk_sampler():
  245. logger.info("Test Case PKSampler")
  246. # define parameters
  247. repeat_count = 1
  248. # apply dataset operations
  249. sampler = ds.PKSampler(3)
  250. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  251. data1 = data1.repeat(repeat_count)
  252. num_iter = 0
  253. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  254. # in this example, each dictionary has keys "image" and "label"
  255. logger.info("image is {}".format(item["image"]))
  256. logger.info("label is {}".format(item["label"]))
  257. num_iter += 1
  258. logger.info("Number of data in data1: {}".format(num_iter))
  259. assert num_iter == 12
  260. def test_subset_random_sampler():
  261. logger.info("Test Case SubsetRandomSampler")
  262. # define parameters
  263. repeat_count = 1
  264. # apply dataset operations
  265. indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
  266. sampler = ds.SubsetRandomSampler(indices)
  267. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  268. data1 = data1.repeat(repeat_count)
  269. num_iter = 0
  270. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  271. # in this example, each dictionary has keys "image" and "label"
  272. logger.info("image is {}".format(item["image"]))
  273. logger.info("label is {}".format(item["label"]))
  274. num_iter += 1
  275. logger.info("Number of data in data1: {}".format(num_iter))
  276. assert num_iter == 12
  277. def test_weighted_random_sampler():
  278. logger.info("Test Case WeightedRandomSampler")
  279. # define parameters
  280. repeat_count = 1
  281. # apply dataset operations
  282. weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]
  283. sampler = ds.WeightedRandomSampler(weights, 11)
  284. data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
  285. data1 = data1.repeat(repeat_count)
  286. num_iter = 0
  287. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  288. # in this example, each dictionary has keys "image" and "label"
  289. logger.info("image is {}".format(item["image"]))
  290. logger.info("label is {}".format(item["label"]))
  291. num_iter += 1
  292. logger.info("Number of data in data1: {}".format(num_iter))
  293. assert num_iter == 11
  294. def test_weighted_random_sampler_exception():
  295. """
  296. Test error cases for WeightedRandomSampler
  297. """
  298. logger.info("Test error cases for WeightedRandomSampler")
  299. error_msg_1 = "type of weights element should be number"
  300. with pytest.raises(TypeError, match=error_msg_1):
  301. weights = ""
  302. ds.WeightedRandomSampler(weights)
  303. error_msg_2 = "type of weights element should be number"
  304. with pytest.raises(TypeError, match=error_msg_2):
  305. weights = (0.9, 0.8, 1.1)
  306. ds.WeightedRandomSampler(weights)
  307. error_msg_3 = "weights size should not be 0"
  308. with pytest.raises(ValueError, match=error_msg_3):
  309. weights = []
  310. ds.WeightedRandomSampler(weights)
  311. error_msg_4 = "weights should not contain negative numbers"
  312. with pytest.raises(ValueError, match=error_msg_4):
  313. weights = [1.0, 0.1, 0.02, 0.3, -0.4]
  314. ds.WeightedRandomSampler(weights)
  315. error_msg_5 = "elements of weights should not be all zero"
  316. with pytest.raises(ValueError, match=error_msg_5):
  317. weights = [0, 0, 0, 0, 0]
  318. ds.WeightedRandomSampler(weights)
  319. def test_imagefolder_rename():
  320. logger.info("Test Case rename")
  321. # define parameters
  322. repeat_count = 1
  323. # apply dataset operations
  324. data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
  325. data1 = data1.repeat(repeat_count)
  326. num_iter = 0
  327. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  328. # in this example, each dictionary has keys "image" and "label"
  329. logger.info("image is {}".format(item["image"]))
  330. logger.info("label is {}".format(item["label"]))
  331. num_iter += 1
  332. logger.info("Number of data in data1: {}".format(num_iter))
  333. assert num_iter == 10
  334. data1 = data1.rename(input_columns=["image"], output_columns="image2")
  335. num_iter = 0
  336. for item in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  337. # in this example, each dictionary has keys "image" and "label"
  338. logger.info("image is {}".format(item["image2"]))
  339. logger.info("label is {}".format(item["label"]))
  340. num_iter += 1
  341. logger.info("Number of data in data1: {}".format(num_iter))
  342. assert num_iter == 10
  343. def test_imagefolder_zip():
  344. logger.info("Test Case zip")
  345. # define parameters
  346. repeat_count = 2
  347. # apply dataset operations
  348. data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
  349. data2 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
  350. data1 = data1.repeat(repeat_count)
  351. # rename dataset2 for no conflict
  352. data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
  353. data3 = ds.zip((data1, data2))
  354. num_iter = 0
  355. for item in data3.create_dict_iterator(num_epochs=1): # each data is a dictionary
  356. # in this example, each dictionary has keys "image" and "label"
  357. logger.info("image is {}".format(item["image"]))
  358. logger.info("label is {}".format(item["label"]))
  359. num_iter += 1
  360. logger.info("Number of data in data1: {}".format(num_iter))
  361. assert num_iter == 10
  362. if __name__ == '__main__':
  363. test_imagefolder_basic()
  364. logger.info('test_imagefolder_basic Ended.\n')
  365. test_imagefolder_numsamples()
  366. logger.info('test_imagefolder_numsamples Ended.\n')
  367. test_sequential_sampler()
  368. logger.info('test_sequential_sampler Ended.\n')
  369. test_random_sampler()
  370. logger.info('test_random_sampler Ended.\n')
  371. test_distributed_sampler()
  372. logger.info('test_distributed_sampler Ended.\n')
  373. test_pk_sampler()
  374. logger.info('test_pk_sampler Ended.\n')
  375. test_subset_random_sampler()
  376. logger.info('test_subset_random_sampler Ended.\n')
  377. test_weighted_random_sampler()
  378. logger.info('test_weighted_random_sampler Ended.\n')
  379. test_weighted_random_sampler_exception()
  380. logger.info('test_weighted_random_sampler_exception Ended.\n')
  381. test_imagefolder_numshards()
  382. logger.info('test_imagefolder_numshards Ended.\n')
  383. test_imagefolder_shardid()
  384. logger.info('test_imagefolder_shardid Ended.\n')
  385. test_imagefolder_noshuffle()
  386. logger.info('test_imagefolder_noshuffle Ended.\n')
  387. test_imagefolder_extrashuffle()
  388. logger.info('test_imagefolder_extrashuffle Ended.\n')
  389. test_imagefolder_classindex()
  390. logger.info('test_imagefolder_classindex Ended.\n')
  391. test_imagefolder_negative_classindex()
  392. logger.info('test_imagefolder_negative_classindex Ended.\n')
  393. test_imagefolder_extensions()
  394. logger.info('test_imagefolder_extensions Ended.\n')
  395. test_imagefolder_decode()
  396. logger.info('test_imagefolder_decode Ended.\n')
  397. test_imagefolder_rename()
  398. logger.info('test_imagefolder_rename Ended.\n')
  399. test_imagefolder_zip()
  400. logger.info('test_imagefolder_zip Ended.\n')