
test_datasets_imagefolder.py 30 kB

# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testPK/data"
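The expected counts asserted throughout these tests imply a fixed layout for the testPK/data folder: 4 class subfolders with 11 images each, 44 images in total. A minimal sketch of that arithmetic (the layout itself is an assumption inferred from the assertions below, not stated in this file):

```python
# Assumed layout of ../data/dataset/testPK/data, inferred from the assertions:
# 4 class subfolders, each holding 11 images.
num_classes = 4
images_per_class = 11

total = num_classes * images_per_class    # 44: test_imagefolder_basic
per_shard = total // 4                    # 11: the num_shards=4 tests
pk_total = num_classes * 3                # 12: PKSampler(3), 3 samples per class
dist_10_1 = (total + 10 - 1) // 10        # 5: DistributedSampler(10, 1), ceil(44 / 10)

print(total, per_shard, pk_total, dist_10_1)
```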

def test_imagefolder_basic():
    logger.info("Test Case basic")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    data1 = ds.ImageFolderDataset(DATA_DIR)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 44

def test_imagefolder_numsamples():
    logger.info("Test Case numSamples")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10, num_parallel_workers=2)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 10

    random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
    data1 = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

    random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
    data1 = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

def test_imagefolder_numshards():
    logger.info("Test Case numShards")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    data1 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 11

def test_imagefolder_shardid():
    logger.info("Test Case withShardID")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    data1 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=1)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 11

def test_imagefolder_noshuffle():
    logger.info("Test Case noShuffle")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    data1 = ds.ImageFolderDataset(DATA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 44

def test_imagefolder_extrashuffle():
    logger.info("Test Case extraShuffle")
    # define parameters
    repeat_count = 2

    # apply dataset operations
    data1 = ds.ImageFolderDataset(DATA_DIR, shuffle=True)
    data1 = data1.shuffle(buffer_size=5)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 88

def test_imagefolder_classindex():
    logger.info("Test Case classIndex")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    class_index = {"class3": 333, "class1": 111}
    data1 = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False)
    data1 = data1.repeat(repeat_count)

    golden = [111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
              333, 333, 333, 333, 333, 333, 333, 333, 333, 333, 333]

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        assert item["label"] == golden[num_iter]
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 22

def test_imagefolder_negative_classindex():
    logger.info("Test Case negative classIndex")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    class_index = {"class3": -333, "class1": 111}
    data1 = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False)
    data1 = data1.repeat(repeat_count)

    golden = [111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
              -333, -333, -333, -333, -333, -333, -333, -333, -333, -333, -333]

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        assert item["label"] == golden[num_iter]
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 22

def test_imagefolder_extensions():
    logger.info("Test Case extensions")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    ext = [".jpg", ".JPEG"]
    data1 = ds.ImageFolderDataset(DATA_DIR, extensions=ext)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 44

def test_imagefolder_decode():
    logger.info("Test Case decode")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    ext = [".jpg", ".JPEG"]
    data1 = ds.ImageFolderDataset(DATA_DIR, extensions=ext, decode=True)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 44

def test_sequential_sampler():
    logger.info("Test Case SequentialSampler")
    golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
              3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]

    # define parameters
    repeat_count = 1

    # apply dataset operations
    sampler = ds.SequentialSampler()
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    result = []
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        result.append(item["label"])
        num_iter += 1

    assert num_iter == 44
    logger.info("Result: {}".format(result))
    assert result == golden

def test_random_sampler():
    logger.info("Test Case RandomSampler")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    sampler = ds.RandomSampler()
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 44

def test_distributed_sampler():
    logger.info("Test Case DistributedSampler")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    sampler = ds.DistributedSampler(10, 1)
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 5

def test_pk_sampler():
    logger.info("Test Case PKSampler")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    sampler = ds.PKSampler(3)
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 12

def test_subset_random_sampler():
    logger.info("Test Case SubsetRandomSampler")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
    sampler = ds.SubsetRandomSampler(indices)
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 12

def test_weighted_random_sampler():
    logger.info("Test Case WeightedRandomSampler")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 11

def test_weighted_random_sampler_exception():
    """
    Test error cases for WeightedRandomSampler
    """
    logger.info("Test error cases for WeightedRandomSampler")
    error_msg_1 = "type of weights element must be number"
    with pytest.raises(TypeError, match=error_msg_1):
        weights = ""
        ds.WeightedRandomSampler(weights)

    error_msg_2 = "type of weights element must be number"
    with pytest.raises(TypeError, match=error_msg_2):
        weights = (0.9, 0.8, 1.1)
        ds.WeightedRandomSampler(weights)

    error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
    with pytest.raises(RuntimeError, match=error_msg_3):
        weights = []
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()

    error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
    with pytest.raises(RuntimeError, match=error_msg_4):
        weights = [1.0, 0.1, 0.02, 0.3, -0.4]
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()

    error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
    with pytest.raises(RuntimeError, match=error_msg_5):
        weights = [0, 0, 0, 0, 0]
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()

def test_chained_sampler_01():
    logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")

    # Create chained sampler, random and sequential
    sampler = ds.RandomSampler()
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(count=3)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 132

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 132

def test_chained_sampler_02():
    logger.info("Test Case Chained Sampler - Random and Sequential, with batch then repeat")

    # Create chained sampler, random and sequential
    sampler = ds.RandomSampler()
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.batch(batch_size=5, drop_remainder=True)
    data1 = data1.repeat(count=2)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 16

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 16

def test_chained_sampler_03():
    logger.info("Test Case Chained Sampler - Random and Sequential, with repeat then batch")

    # Create chained sampler, random and sequential
    sampler = ds.RandomSampler()
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(count=2)
    data1 = data1.batch(batch_size=5, drop_remainder=False)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 18

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 18

def test_chained_sampler_04():
    logger.info("Test Case Chained Sampler - Distributed and Random, with batch then repeat")

    # Create chained sampler, distributed and random
    sampler = ds.DistributedSampler(num_shards=4, shard_id=3)
    child_sampler = ds.RandomSampler()
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.batch(batch_size=5, drop_remainder=True)
    data1 = data1.repeat(count=3)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 6

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: Each of the 4 shards has 44/4=11 samples
    # Note: Number of iterations is (11/5 = 2) * 3 = 6
    assert num_iter == 6

def skip_test_chained_sampler_05():
    logger.info("Test Case Chained Sampler - PKSampler and WeightedRandom")

    # Create chained sampler, PKSampler and WeightedRandom
    sampler = ds.PKSampler(num_val=3)  # Number of elements per class is 3 (and there are 4 classes)
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 0.5]
    child_sampler = ds.WeightedRandomSampler(weights, num_samples=12)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 12

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: PKSampler produces 4x3=12 samples
    # Note: Child WeightedRandomSampler produces 12 samples
    assert num_iter == 12

def test_chained_sampler_06():
    logger.info("Test Case Chained Sampler - WeightedRandom and PKSampler")

    # Create chained sampler, WeightedRandom and PKSampler
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 0.5]
    sampler = ds.WeightedRandomSampler(weights=weights, num_samples=12)
    child_sampler = ds.PKSampler(num_val=3)  # Number of elements per class is 3 (and there are 4 classes)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 12

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each item is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: WeightedRandomSampler produces 12 samples
    # Note: Child PKSampler produces 12 samples
    assert num_iter == 12


def test_chained_sampler_07():
    logger.info("Test Case Chained Sampler - SubsetRandom and Distributed, 2 shards")

    # Create chained sampler, subset random and distributed
    indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
    sampler = ds.SubsetRandomSampler(indices, num_samples=12)
    child_sampler = ds.DistributedSampler(num_shards=2, shard_id=1)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 12

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: SubsetRandomSampler produces 12 samples
    # Note: Each of 2 shards has 6 samples
    # FIXME: Uncomment the following assert when code issue is resolved; at runtime, number of samples is 12 not 6
    # assert num_iter == 6
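
# A pure-Python sketch of the round-robin sharding the FIXME above expects:
# a DistributedSampler child with num_shards=2, shard_id=1 should keep every
# 2nd of the 12 parent indices, i.e. 6 samples per shard. The helper below
# is illustrative only and is not the real sampler implementation.
def expected_shard(indices, num_shards, shard_id):
    # shard k takes parent indices k, k + num_shards, k + 2*num_shards, ...
    return indices[shard_id::num_shards]

assert len(expected_shard([0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11], 2, 1)) == 6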


def skip_test_chained_sampler_08():
    logger.info("Test Case Chained Sampler - SubsetRandom and Distributed, 4 shards")

    # Create chained sampler, subset random and distributed
    indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
    sampler = ds.SubsetRandomSampler(indices, num_samples=12)
    child_sampler = ds.DistributedSampler(num_shards=4, shard_id=1)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 3

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: SubsetRandomSampler returns 12 samples
    # Note: Each of 4 shards has 3 samples
    assert num_iter == 3


def test_imagefolder_rename():
    logger.info("Test Case rename")
    # define parameters
    repeat_count = 1

    # apply dataset operations
    data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 10

    data1 = data1.rename(input_columns=["image"], output_columns="image2")

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        # after the rename, each dictionary has keys "image2" and "label"
        logger.info("image is {}".format(item["image2"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 10
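
# A plain-dict sketch of the column rename performed above (illustrative
# only, not the MindSpore internals): each row keeps its values and only
# the "image" key changes to "image2". The helper name is an assumption.
def rename_row(row, input_columns, output_columns):
    renamed = dict(row)
    for old, new in zip(input_columns, output_columns):
        # move the value under its new column name
        renamed[new] = renamed.pop(old)
    return renamed

assert set(rename_row({"image": b"", "label": 0}, ["image"], ["image2"])) == {"image2", "label"}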


def test_imagefolder_zip():
    logger.info("Test Case zip")
    # define parameters
    repeat_count = 2

    # apply dataset operations
    data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
    data2 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)

    data1 = data1.repeat(repeat_count)
    # rename dataset2 for no conflict
    data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
    data3 = ds.zip((data1, data2))

    num_iter = 0
    for item in data3.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data3: {}".format(num_iter))
    assert num_iter == 10
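
# A plain-dict sketch of how ds.zip() pairs rows (illustrative only, not
# the MindSpore internals): rows are merged column-by-column and iteration
# stops at the shorter input, which is why data1.repeat(2) (20 rows) zipped
# with data2 (10 rows) still yields 10 rows above.
def zip_rows(rows1, rows2):
    # merge each pair of rows into one dict; truncates to the shorter list
    return [{**r1, **r2} for r1, r2 in zip(rows1, rows2)]

assert len(zip_rows([{"image": i} for i in range(20)],
                    [{"image1": i} for i in range(10)])) == 10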


def test_imagefolder_exception():
    logger.info("Test imagefolder exception")

    def exception_func(item):
        raise Exception("Error occurred!")

    def exception_func2(image, label):
        raise Exception("Error occurred!")

    try:
        data = ds.ImageFolderDataset(DATA_DIR)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.ImageFolderDataset(DATA_DIR)
        data = data.map(operations=exception_func2, input_columns=["image", "label"],
                        output_columns=["image", "label", "label1"],
                        column_order=["image", "label", "label1"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.ImageFolderDataset(DATA_DIR)
        data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)


if __name__ == '__main__':
    test_imagefolder_basic()
    logger.info('test_imagefolder_basic Ended.\n')
    test_imagefolder_numsamples()
    logger.info('test_imagefolder_numsamples Ended.\n')
    test_sequential_sampler()
    logger.info('test_sequential_sampler Ended.\n')
    test_random_sampler()
    logger.info('test_random_sampler Ended.\n')
    test_distributed_sampler()
    logger.info('test_distributed_sampler Ended.\n')
    test_pk_sampler()
    logger.info('test_pk_sampler Ended.\n')
    test_subset_random_sampler()
    logger.info('test_subset_random_sampler Ended.\n')
    test_weighted_random_sampler()
    logger.info('test_weighted_random_sampler Ended.\n')
    test_weighted_random_sampler_exception()
    logger.info('test_weighted_random_sampler_exception Ended.\n')
    test_chained_sampler_01()
    logger.info('test_chained_sampler_01 Ended.\n')
    test_chained_sampler_02()
    logger.info('test_chained_sampler_02 Ended.\n')
    test_chained_sampler_03()
    logger.info('test_chained_sampler_03 Ended.\n')
    test_chained_sampler_04()
    logger.info('test_chained_sampler_04 Ended.\n')
    # test_chained_sampler_05()
    # logger.info('test_chained_sampler_05 Ended.\n')
    test_chained_sampler_06()
    logger.info('test_chained_sampler_06 Ended.\n')
    test_chained_sampler_07()
    logger.info('test_chained_sampler_07 Ended.\n')
    # skip_test_chained_sampler_08()
    # logger.info('test_chained_sampler_08 Ended.\n')
    test_imagefolder_numshards()
    logger.info('test_imagefolder_numshards Ended.\n')
    test_imagefolder_shardid()
    logger.info('test_imagefolder_shardid Ended.\n')
    test_imagefolder_noshuffle()
    logger.info('test_imagefolder_noshuffle Ended.\n')
    test_imagefolder_extrashuffle()
    logger.info('test_imagefolder_extrashuffle Ended.\n')
    test_imagefolder_classindex()
    logger.info('test_imagefolder_classindex Ended.\n')
    test_imagefolder_negative_classindex()
    logger.info('test_imagefolder_negative_classindex Ended.\n')
    test_imagefolder_extensions()
    logger.info('test_imagefolder_extensions Ended.\n')
    test_imagefolder_decode()
    logger.info('test_imagefolder_decode Ended.\n')
    test_imagefolder_rename()
    logger.info('test_imagefolder_rename Ended.\n')
    test_imagefolder_zip()
    logger.info('test_imagefolder_zip Ended.\n')
    test_imagefolder_exception()
    logger.info('test_imagefolder_exception Ended.\n')