You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_sampler_chain.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. import mindspore.dataset.transforms.c_transforms as c_transforms
  18. from mindspore import log as logger
  19. from util import save_and_check_md5
  20. GENERATE_GOLDEN = False
  21. IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train"
  22. IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
  23. "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
  24. "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
  25. "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
  26. MNIST_DATA_DIR = "../data/dataset/testMnistData"
  27. MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
  28. CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data"
  29. COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
  30. ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
  31. VOC_DATA_DIR = "../data/dataset/testVOC2012"
  32. def test_numpyslices_sampler_no_chain():
  33. """
  34. Test NumpySlicesDataset with sampler, no chain
  35. """
  36. logger.info("test_numpyslices_sampler_no_chain")
  37. # Create NumpySlicesDataset with sampler, no chain
  38. np_data = [1, 2, 3, 4]
  39. sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  40. data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
  41. # Verify dataset size
  42. data1_size = data1.get_dataset_size()
  43. logger.info("dataset size is: {}".format(data1_size))
  44. assert data1_size == 2
  45. # Verify number of rows
  46. assert sum([1 for _ in data1]) == 2
  47. # Verify dataset contents
  48. res = []
  49. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  50. logger.info("item: {}".format(item))
  51. res.append(item)
  52. logger.info("dataset: {}".format(res))
  53. def test_numpyslices_sampler_chain():
  54. """
  55. Test NumpySlicesDataset sampler chain
  56. """
  57. logger.info("test_numpyslices_sampler_chain")
  58. # Create NumpySlicesDataset with sampler chain
  59. # Use 1 statement to add child sampler
  60. np_data = [1, 2, 3, 4]
  61. sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  62. sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
  63. data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
  64. # Verify dataset size
  65. data1_size = data1.get_dataset_size()
  66. logger.info("dataset size is: {}".format(data1_size))
  67. assert data1_size == 1
  68. # Verify number of rows
  69. assert sum([1 for _ in data1]) == 1
  70. # Verify dataset contents
  71. res = []
  72. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  73. logger.info("item: {}".format(item))
  74. res.append(item)
  75. logger.info("dataset: {}".format(res))
  76. def test_numpyslices_sampler_chain2():
  77. """
  78. Test NumpySlicesDataset sampler chain
  79. """
  80. logger.info("test_numpyslices_sampler_chain2")
  81. # Create NumpySlicesDataset with sampler chain
  82. # Use 2 statements to add child sampler
  83. np_data = [1, 2, 3, 4]
  84. sampler = ds.SequentialSampler(start_index=1, num_samples=1)
  85. child_sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  86. sampler.add_child(child_sampler)
  87. data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
  88. # Verify dataset size
  89. data1_size = data1.get_dataset_size()
  90. logger.info("dataset size is: {}".format(data1_size))
  91. assert data1_size == 1
  92. # Verify number of rows
  93. assert sum([1 for _ in data1]) == 1
  94. # Verify dataset contents
  95. res = []
  96. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  97. logger.info("item: {}".format(item))
  98. res.append(item)
  99. logger.info("dataset: {}".format(res))
  100. def test_imagefolder_sampler_chain():
  101. """
  102. Test ImageFolderDataset sampler chain
  103. """
  104. logger.info("test_imagefolder_sampler_chain")
  105. sampler = ds.SequentialSampler(start_index=1, num_samples=3)
  106. child_sampler = ds.PKSampler(2)
  107. sampler.add_child(child_sampler)
  108. data1 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, sampler=sampler)
  109. # Verify dataset size
  110. data1_size = data1.get_dataset_size()
  111. logger.info("dataset size is: {}".format(data1_size))
  112. assert data1_size == 3
  113. # Verify number of rows
  114. assert sum([1 for _ in data1]) == 3
  115. # Verify dataset contents
  116. res = []
  117. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  118. logger.info("item: {}".format(item))
  119. res.append(item)
  120. logger.info("dataset: {}".format(res))
  121. def test_mnist_sampler_chain():
  122. """
  123. Test Mnist sampler chain
  124. """
  125. logger.info("test_mnist_sampler_chain")
  126. sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
  127. child_sampler = ds.RandomSampler(replacement=True, num_samples=4)
  128. sampler.add_child(child_sampler)
  129. data1 = ds.MnistDataset(MNIST_DATA_DIR, sampler=sampler)
  130. # Verify dataset size
  131. data1_size = data1.get_dataset_size()
  132. logger.info("dataset size is: {}".format(data1_size))
  133. assert data1_size == 3
  134. # Verify number of rows
  135. assert sum([1 for _ in data1]) == 3
  136. # Verify dataset contents
  137. res = []
  138. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  139. logger.info("item: {}".format(item))
  140. res.append(item)
  141. logger.info("dataset: {}".format(res))
  142. def test_manifest_sampler_chain():
  143. """
  144. Test Manifest sampler chain
  145. """
  146. logger.info("test_manifest_sampler_chain")
  147. sampler = ds.RandomSampler(replacement=True, num_samples=2)
  148. child_sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
  149. sampler.add_child(child_sampler)
  150. data1 = ds.ManifestDataset(MANIFEST_DATA_FILE, sampler=sampler)
  151. # Verify dataset size
  152. data1_size = data1.get_dataset_size()
  153. logger.info("dataset size is: {}".format(data1_size))
  154. assert data1_size == 2
  155. # Verify number of rows
  156. assert sum([1 for _ in data1]) == 2
  157. # Verify dataset contents
  158. res = []
  159. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  160. logger.info("item: {}".format(item))
  161. res.append(item)
  162. logger.info("dataset: {}".format(res))
  163. def test_coco_sampler_chain():
  164. """
  165. Test Coco sampler chain
  166. """
  167. logger.info("test_coco_sampler_chain")
  168. sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
  169. child_sampler = ds.RandomSampler(replacement=True, num_samples=2)
  170. sampler.add_child(child_sampler)
  171. data1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
  172. sampler=sampler)
  173. # Verify dataset size
  174. data1_size = data1.get_dataset_size()
  175. logger.info("dataset size is: {}".format(data1_size))
  176. assert data1_size == 1
  177. # Verify number of rows
  178. assert sum([1 for _ in data1]) == 1
  179. # Verify dataset contents
  180. res = []
  181. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  182. logger.info("item: {}".format(item))
  183. res.append(item)
  184. logger.info("dataset: {}".format(res))
  185. def test_cifar_sampler_chain():
  186. """
  187. Test Cifar sampler chain
  188. """
  189. logger.info("test_cifar_sampler_chain")
  190. sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
  191. child_sampler = ds.RandomSampler(replacement=True, num_samples=4)
  192. child_sampler2 = ds.SequentialSampler(start_index=0, num_samples=2)
  193. child_sampler.add_child(child_sampler2)
  194. sampler.add_child(child_sampler)
  195. data1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, sampler=sampler)
  196. # Verify dataset size
  197. data1_size = data1.get_dataset_size()
  198. logger.info("dataset size is: {}".format(data1_size))
  199. assert data1_size == 1
  200. # Verify number of rows
  201. assert sum([1 for _ in data1]) == 1
  202. # Verify dataset contents
  203. res = []
  204. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  205. logger.info("item: {}".format(item))
  206. res.append(item)
  207. logger.info("dataset: {}".format(res))
  208. def test_voc_sampler_chain():
  209. """
  210. Test VOC sampler chain
  211. """
  212. logger.info("test_voc_sampler_chain")
  213. sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
  214. child_sampler = ds.SequentialSampler(start_index=0)
  215. sampler.add_child(child_sampler)
  216. data1 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", sampler=sampler)
  217. # Verify dataset size
  218. data1_size = data1.get_dataset_size()
  219. logger.info("dataset size is: {}".format(data1_size))
  220. assert data1_size == 5
  221. # Verify number of rows
  222. assert sum([1 for _ in data1]) == 5
  223. # Verify dataset contents
  224. res = []
  225. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  226. logger.info("item: {}".format(item))
  227. res.append(item)
  228. logger.info("dataset: {}".format(res))
  229. def test_numpyslices_sampler_chain_batch():
  230. """
  231. Test NumpySlicesDataset sampler chaining, with batch
  232. """
  233. logger.info("test_numpyslices_sampler_chain_batch")
  234. # Create NumpySlicesDataset with sampler chain
  235. np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  236. sampler = ds.SequentialSampler(start_index=1, num_samples=3)
  237. sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
  238. data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
  239. data1 = data1.batch(batch_size=3, drop_remainder=False)
  240. # Verify dataset size
  241. data1_size = data1.get_dataset_size()
  242. logger.info("dataset size is: {}".format(data1_size))
  243. assert data1_size == 4
  244. # Verify number of rows
  245. assert sum([1 for _ in data1]) == 4
  246. # Verify dataset contents
  247. res = []
  248. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  249. logger.info("item: {}".format(item))
  250. res.append(item)
  251. logger.info("dataset: {}".format(res))
  252. def test_sampler_chain_errors():
  253. """
  254. Test error cases for sampler chains
  255. """
  256. logger.info("test_sampler_chain_errors")
  257. error_msg_1 = "'NoneType' object has no attribute 'add_child'"
  258. # Test add child sampler within child sampler
  259. sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  260. sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
  261. with pytest.raises(AttributeError, match=error_msg_1):
  262. sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
  263. # error_msg_2 = "'NoneType' object has no attribute 'add_child'"
  264. # Test add second and nested child sampler
  265. sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  266. child_sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  267. sampler.add_child(child_sampler)
  268. child_sampler2 = ds.SequentialSampler(start_index=1, num_samples=2)
  269. sampler.add_child(child_sampler2)
  270. # FIXME - no error is raised; uncomment after code issue is resolved
  271. # with pytest.raises(AttributeError, match=error_msg_2):
  272. # sampler.add_child(child_sampler2)
  273. # np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  274. # data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
  275. error_msg_3 = "Conflicting arguments during sampler assignments."
  276. # Test conflicting arguments (sampler and shuffle=False) for sampler (no chain)
  277. np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  278. sampler = ds.SequentialSampler(start_index=1, num_samples=3)
  279. with pytest.raises(ValueError, match=error_msg_3):
  280. ds.NumpySlicesDataset(np_data, shuffle=False, sampler=sampler)
  281. # error_msg_4 = "Conflicting arguments during sampler assignments."
  282. # Test conflicting arguments (sampler and shuffle=False) for sampler chaining
  283. np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  284. sampler = ds.SequentialSampler(start_index=1, num_samples=3)
  285. sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
  286. # FIXME - no error is raised; uncomment after code issue is resolved
  287. # with pytest.raises(ValueError, match=error_msg_4):
  288. # ds.NumpySlicesDataset(np_data, shuffle=False, sampler=sampler)
  289. def test_manifest_sampler_chain_repeat():
  290. """
  291. Test ManifestDataset sampler chain DistributedSampler->SequentialSampler, with repeat
  292. """
  293. logger.info("test_manifest_sampler_chain_batch")
  294. manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
  295. # Create sampler chain DistributedSampler->SequentialSampler
  296. sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=5)
  297. child_sampler = ds.SequentialSampler()
  298. sampler.add_child(child_sampler)
  299. # Create ManifestDataset with sampler chain
  300. data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
  301. data1 = data1.repeat(count=2)
  302. # Verify dataset size
  303. data1_size = data1.get_dataset_size()
  304. logger.info("dataset size is: {}".format(data1_size))
  305. assert data1_size == 10
  306. # Verify number of rows
  307. assert sum([1 for _ in data1]) == 10
  308. # Verify dataset contents
  309. filename = "sampler_chain_manifest_repeat_result.npz"
  310. save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
  311. def test_manifest_sampler_chain_batch_repeat():
  312. """
  313. Test ManifestDataset sampler chain DistributedSampler->SequentialSampler, with batch then repeat
  314. """
  315. logger.info("test_manifest_sampler_chain_batch_repeat")
  316. manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
  317. # Create sampler chain DistributedSampler->SequentialSampler
  318. sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=5)
  319. child_sampler = ds.SequentialSampler()
  320. sampler.add_child(child_sampler)
  321. # Create ManifestDataset with sampler chain
  322. data1 = ds.ManifestDataset(manifest_file, decode=True, sampler=sampler)
  323. one_hot_encode = c_transforms.OneHot(3)
  324. data1 = data1.map(operations=one_hot_encode, input_columns=["label"])
  325. data1 = data1.batch(batch_size=5, drop_remainder=False)
  326. data1 = data1.repeat(count=2)
  327. # Verify dataset size
  328. data1_size = data1.get_dataset_size()
  329. logger.info("dataset size is: {}".format(data1_size))
  330. assert data1_size == 2
  331. # Verify number of rows
  332. # FIXME: Uncomment the following assert when code issue is resolved
  333. # assert sum([1 for _ in data1]) == 2
  334. if __name__ == '__main__':
  335. test_numpyslices_sampler_no_chain()
  336. test_numpyslices_sampler_chain()
  337. test_numpyslices_sampler_chain2()
  338. test_imagefolder_sampler_chain()
  339. test_mnist_sampler_chain()
  340. test_manifest_sampler_chain()
  341. test_coco_sampler_chain()
  342. test_cifar_sampler_chain()
  343. test_voc_sampler_chain()
  344. test_numpyslices_sampler_chain_batch()
  345. test_sampler_chain_errors()
  346. test_manifest_sampler_chain_repeat()
  347. test_manifest_sampler_chain_batch_repeat()