
test_minddataset_sampler.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test module for the samplers supported by MindDataset (the MindRecord reader).
"""
import os
import pytest
import numpy as np

import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.dataset.text import to_str
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"


@pytest.fixture
def add_and_remove_cv_file():
    """Create the MindRecord shard files for a test and remove them afterwards."""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME, True)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
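

# The fixture above shards the dataset across FILES_NUM files
# (imagenet.mindrecord0 .. imagenet.mindrecord3, each paired with a ".db"
# index file). The tests below open only the first shard yet assert the full
# row count, so MindDataset evidently resolves the sibling shards from that
# single path. A minimal read-back sketch of the pattern every test repeats
# (illustrative only; the helper name is ours and is not collected by pytest):
def _sketch_read_back():
    """Sketch: open the written shards and iterate them once."""
    data_set = ds.MindDataset(CV_FILE_NAME + "0", ["file_name", "label"], 4)
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("file_name: {}, label: {}".format(
            to_str(item["file_name"]), item["label"]))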


def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
    """PKSampler(2) without a columns_list: two rows per label class."""
    num_readers = 4
    sampler = ds.PKSampler(2)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1


def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
    """PKSampler(2) with an explicit columns_list."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(2)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {} ------------------------".format(item["data"][:10]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1


def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
    """PKSampler(3) with shuffle enabled: three rows per class."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(3, None, True)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 9
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 9


def test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file):
    """PKSampler(3) shuffled, with num_samples=5 capping the output."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(3, None, True, 'label', 5)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 5


def test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file):
    """PKSampler(3) shuffled, with num_samples=10 above the 9 available rows."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(3, None, True, 'label', 10)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 9
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 9


def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file):
    """PKSampler(5): num_val exceeds the per-class row count."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(5, None, True)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 15
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 15


def test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file):
    """PKSampler(5) with num_samples=20, above the 15 sampled rows."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(5, None, True, 'label', 20)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 15
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 15


def test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file):
    """PKSampler(5) with num_samples=10 capping the 15 sampled rows."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(5, None, True, 'label', 10)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10
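

# Summary of the PKSampler behavior the asserts above pin down: its positional
# arguments, as used in this module, are (num_val, num_class, shuffle,
# class_column, num_samples). num_val rows are drawn per label class; the
# sizes 6, 9 and 15 for num_val = 2, 3, 5 imply the sampler annotation file
# holds 3 classes, and num_val = 5 still yields 15 rows even though 3 classes
# cannot all supply 5 of the 10 images, so some rows are evidently reused.
# num_samples, when set, caps the final count (5 and 10 above) and is a no-op
# once it exceeds the sampled total (20). A sketch of the full signature
# (illustrative only; the helper name is ours):
def _sketch_pk_sampler():
    """Sketch: 3 rows per class, shuffled, capped at 5 rows overall."""
    sampler = ds.PKSampler(3, None, True, 'label', 5)
    return ds.MindDataset(CV_FILE_NAME + "0", ["file_name", "label"], 4,
                          sampler=sampler)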


def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
    """SubsetRandomSampler and SubsetSampler over an explicit index list."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 5
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 5


def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
    """Subset samplers accept duplicate indices."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 2, 5, 7, 9]
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 6


def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
    """Subset samplers with an empty index list yield an empty dataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = []
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 0
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 0


def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file):
    """Subset samplers with indices past the end of the 10-row dataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 4, 11, 13]
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 5
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 5


def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
    """Subset samplers accept negative indices."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 4, -1, -2]
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 5
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 5
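

# The subset samplers take an explicit list of row indices: duplicates are
# honored (6 rows from [1, 2, 2, 5, 7, 9]), an empty list yields an empty
# dataset, and negative indices are accepted. Note that indices past the
# 10-row dataset ([1, 2, 4, 11, 13]) still count toward the size and still
# produce rows per the asserts above, so out-of-range values appear to be
# remapped rather than rejected. A sketch (illustrative only; the helper
# name is ours):
def _sketch_subset_samplers():
    """Sketch: draw rows 1, 3 and 5, in order and then in random order."""
    for sampler in (ds.SubsetSampler([1, 3, 5]), ds.SubsetRandomSampler([1, 3, 5])):
        data_set = ds.MindDataset(CV_FILE_NAME + "0", ["label"], 4,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 3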


def test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file):
    """RandomSampler without arguments: a permutation of all 10 rows."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler()
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    new_dataset = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
        new_dataset.append(item['file_name'])
    assert num_iter == 10
    assert new_dataset != [x['file_name'] for x in data]


def test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file):
    """RandomSampler combined with repeat(3): each epoch is reshuffled."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler()
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 10
    ds1 = data_set.repeat(3)
    num_iter = 0
    epoch1_dataset = []
    epoch2_dataset = []
    epoch3_dataset = []
    for item in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 10:
            epoch1_dataset.append(item['file_name'])
        elif num_iter <= 20:
            epoch2_dataset.append(item['file_name'])
        else:
            epoch3_dataset.append(item['file_name'])
    assert num_iter == 30
    assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
    assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
    assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)


def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
    """RandomSampler with replacement: exactly num_samples rows are drawn."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler(replacement=True, num_samples=5)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 5


def test_cv_minddataset_random_sampler_replacement_false_1(add_and_remove_cv_file):
    """RandomSampler without replacement, num_samples below the dataset size."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler(replacement=False, num_samples=2)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 2
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 2


def test_cv_minddataset_random_sampler_replacement_false_2(add_and_remove_cv_file):
    """RandomSampler without replacement: num_samples above the dataset size is clamped."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler(replacement=False, num_samples=20)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10
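

# RandomSampler semantics pinned down above: with replacement=True,
# num_samples rows are drawn and may repeat; with replacement=False the draw
# is a permutation, so num_samples is clamped to the dataset size (20
# requested, 10 returned). A sketch (illustrative only; the helper name is
# ours):
def _sketch_random_sampler():
    """Sketch: 5 draws with replacement from the 10-row dataset."""
    sampler = ds.RandomSampler(replacement=True, num_samples=5)
    return ds.MindDataset(CV_FILE_NAME + "0", ["label"], 4, sampler=sampler)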


def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
    """SequentialSampler(1, 4): four rows starting at index 1, in file order."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.SequentialSampler(1, 4)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 4
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[num_iter + 1]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 4


def test_cv_minddataset_sequential_sampler_offset(add_and_remove_cv_file):
    """SequentialSampler(2, 10): the window wraps around the end of the dataset."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.SequentialSampler(2, 10)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    dataset_size = data_set.get_dataset_size()
    assert dataset_size == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 10


def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
    """SequentialSampler(2, 20): num_samples above the dataset size is clamped."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.SequentialSampler(2, 20)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    dataset_size = data_set.get_dataset_size()
    assert dataset_size == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 10
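

# SequentialSampler's positional arguments here are (start_index,
# num_samples). The `(num_iter + 2) % dataset_size` asserts show that a
# window starting at index 2 with 10 samples wraps past the end of the 10-row
# dataset, and that num_samples beyond the dataset size is clamped (20
# requested, 10 returned). A sketch (illustrative only; the helper name is
# ours):
def _sketch_sequential_sampler():
    """Sketch: rows 1..4 of the dataset, in file order."""
    sampler = ds.SequentialSampler(1, 4)
    return ds.MindDataset(CV_FILE_NAME + "0", ["file_name"], 4, sampler=sampler)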


def test_cv_minddataset_split_basic(add_and_remove_cv_file):
    """split() with absolute sizes [8, 2] and randomize=False keeps file order."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    d1, d2 = d.split([8, 2], randomize=False)
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 8
    assert d2.get_dataset_size() == 2
    num_iter = 0
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(data[num_iter]['file_name'],
                                             dtype='S')
        num_iter += 1
    assert num_iter == 8
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
                                             dtype='S')
        num_iter += 1
    assert num_iter == 2


def test_cv_minddataset_split_exact_percent(add_and_remove_cv_file):
    """split() with exact fractions [0.8, 0.2]."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    d1, d2 = d.split([0.8, 0.2], randomize=False)
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 8
    assert d2.get_dataset_size() == 2
    num_iter = 0
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[num_iter]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 8
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
                                             dtype='S')
        num_iter += 1
    assert num_iter == 2


def test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file):
    """split() with fractions that need rounding: [0.41, 0.59] gives 4 and 6 rows."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    d1, d2 = d.split([0.41, 0.59], randomize=False)
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 4
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[num_iter]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 4
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(data[num_iter + 4]['file_name'],
                                             dtype='S')
        num_iter += 1
    assert num_iter == 6


def test_cv_minddataset_split_deterministic(add_and_remove_cv_file):
    """A seeded randomized split is reproducible and its halves are disjoint."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    # set the seed so the randomized split is reproducible and does not overlap
    ds.config.set_seed(111)
    d1, d2 = d.split([0.8, 0.2])
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 8
    assert d2.get_dataset_size() == 2
    d1_dataset = []
    d2_dataset = []
    num_iter = 0
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        d1_dataset.append(item['file_name'])
        num_iter += 1
    assert num_iter == 8
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        d2_dataset.append(item['file_name'])
        num_iter += 1
    assert num_iter == 2
    inter_dataset = [x for x in d1_dataset if x in d2_dataset]
    assert inter_dataset == []  # the intersection of d1 and d2 must be empty
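

# Dataset.split accepts absolute row counts ([8, 2]) or fractions
# ([0.8, 0.2]); fractions that do not divide the dataset evenly are rounded
# ([0.41, 0.59] over 10 rows gives 4 and 6). randomize=False keeps file
# order, which is what lets the tests above compare against get_data(); with
# randomization on, a fixed seed makes the split reproducible and the halves
# disjoint. A sketch (illustrative only; the helper name is ours):
def _sketch_split():
    """Sketch: a seeded, randomized 80/20 split."""
    ds.config.set_seed(111)
    d = ds.MindDataset(CV_FILE_NAME + "0", ["file_name"], 4, shuffle=False)
    train, val = d.split([0.8, 0.2])
    return train, val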


def test_cv_minddataset_split_sharding(add_and_remove_cv_file):
    """DistributedSampler applied to a split: two disjoint 4-row shards."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    # set the seed so the randomized split is reproducible and does not overlap
    ds.config.set_seed(111)
    d1, d2 = d.split([0.8, 0.2])
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 8
    assert d2.get_dataset_size() == 2
    distributed_sampler = ds.DistributedSampler(2, 0)
    d1.use_sampler(distributed_sampler)
    assert d1.get_dataset_size() == 4
    num_iter = 0
    d1_shard1 = []
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
        d1_shard1.append(item['file_name'])
    assert num_iter == 4
    assert d1_shard1 != [x['file_name'] for x in data[0:4]]

    distributed_sampler = ds.DistributedSampler(2, 1)
    d1.use_sampler(distributed_sampler)
    assert d1.get_dataset_size() == 4
    d1s = d1.repeat(3)
    epoch1_dataset = []
    epoch2_dataset = []
    epoch3_dataset = []
    num_iter = 0
    for item in d1s.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 4:
            epoch1_dataset.append(item['file_name'])
        elif num_iter <= 8:
            epoch2_dataset.append(item['file_name'])
        else:
            epoch3_dataset.append(item['file_name'])
    assert len(epoch1_dataset) == 4
    assert len(epoch2_dataset) == 4
    assert len(epoch3_dataset) == 4
    inter_dataset = [x for x in d1_shard1 if x in epoch1_dataset]
    assert inter_dataset == []  # shard 0 and shard 1 of d1 must not overlap
    assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
    assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
    assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
    epoch1_dataset.sort()
    epoch2_dataset.sort()
    epoch3_dataset.sort()
    assert epoch1_dataset != epoch2_dataset
    assert epoch2_dataset != epoch3_dataset
    assert epoch3_dataset != epoch1_dataset
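

# Sharding after a split: use_sampler() replaces the split's sampler with a
# DistributedSampler(num_shards, shard_id), so each of the 2 shards of the
# 8-row train split reports 4 rows. The asserts above show shard 0 and shard 1
# are disjoint, and that even the sorted per-epoch lists differ under
# repeat(3), i.e. the global shuffle reassigns rows to shards every epoch.
# A sketch (illustrative only; the helper name is ours):
def _sketch_sharded_split():
    """Sketch: shard 0 of 2 over the 8-row train split."""
    ds.config.set_seed(111)
    d = ds.MindDataset(CV_FILE_NAME + "0", ["file_name"], 4, shuffle=False)
    train, _ = d.split([0.8, 0.2])
    train.use_sampler(ds.DistributedSampler(2, 0))
    return train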


def get_data(dir_name, sampler=False):
    """
    Read the ImageNet-style test data.

    Args:
        dir_name: directory containing the images folder and the annotation files.
        sampler: if True, read annotation_sampler.txt instead of annotation.txt.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    if sampler:
        ann_file = os.path.join(dir_name, "annotation_sampler.txt")
    else:
        ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list
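

# get_data() expects one "filename,label" pair per line in the annotation
# file, e.g. (hypothetical contents of annotation_sampler.txt; the real file
# ships with the test data):
#
#     image_00001.jpg,0
#     image_00002.jpg,0
#     image_00003.jpg,1
#
# Images missing from the images folder are skipped rather than treated as
# errors.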


if __name__ == '__main__':
    test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_replacement_false_1(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_replacement_false_2(add_and_remove_cv_file)
    test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file)
    test_cv_minddataset_sequential_sampler_offset(add_and_remove_cv_file)
    test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file)
    test_cv_minddataset_split_basic(add_and_remove_cv_file)
    test_cv_minddataset_split_exact_percent(add_and_remove_cv_file)
    test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file)
    test_cv_minddataset_split_deterministic(add_and_remove_cv_file)
    test_cv_minddataset_split_sharding(add_and_remove_cv_file)