You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_paddeddataset.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. from io import BytesIO
  2. import copy
  3. import os
  4. import numpy as np
  5. import pytest
  6. import mindspore.dataset as ds
  7. from mindspore.mindrecord import FileWriter
  8. import mindspore.dataset.vision.c_transforms as V_C
  9. from PIL import Image
# Number of shard files FileWriter splits the MindRecord output into.
FILES_NUM = 4
# Path prefix for the generated MindRecord shard files (suffix 0..FILES_NUM-1).
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
# Source directory holding images/ and annotation.txt consumed by get_data().
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
  13. def generator_5():
  14. for i in range(0, 5):
  15. yield (np.array([i]),)
  16. def generator_8():
  17. for i in range(5, 8):
  18. yield (np.array([i]),)
  19. def generator_10():
  20. for i in range(0, 10):
  21. yield (np.array([i]),)
  22. def generator_20():
  23. for i in range(10, 20):
  24. yield (np.array([i]),)
  25. def generator_30():
  26. for i in range(20, 30):
  27. yield (np.array([i]),)
def test_TFRecord_Padded():
    """Concat a 3-image TFRecordDataset with 5 padded rows and shard 4 ways.

    Each inner list of result_list is the per-row byte length of the 'image'
    column expected in that shard; the large values are the real JPEG sizes,
    the small values (1..5) come from the zero-filled padded samples.
    """
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    result_list = [[159109, 2], [192607, 3], [179251, 4], [1, 5]]
    verify_list = []
    shard_num = 4
    for i in range(shard_num):
        # Rebuild the source pipeline for every shard id.
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"],
                                  shuffle=False, shard_equal_rows=True)
        # Padded rows are distinguishable by their 'image' length (1..5).
        padded_samples = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
                          {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
                          {'image': np.zeros(5, np.uint8)}]
        padded_ds = ds.PaddedDataset(padded_samples)
        concat_ds = data + padded_ds
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        concat_ds.use_sampler(testsampler)
        shard_list = []
        for item in concat_ds.create_dict_iterator(num_epochs=1, output_numpy=True):
            shard_list.append(len(item['image']))
        verify_list.append(shard_list)
    assert verify_list == result_list
def test_GeneratorDataSet_Padded():
    """Concat two GeneratorDatasets (values 0..9 and 10..19), shard 10 ways.

    Shard i must see exactly the pair [i, 10 + i].
    """
    result_list = []
    for i in range(10):
        tem_list = []
        tem_list.append(i)
        tem_list.append(10 + i)
        result_list.append(tem_list)
    verify_list = []
    data1 = ds.GeneratorDataset(generator_20, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    # data2 (0..9) first, then data1 (10..19).
    data3 = data2 + data1
    shard_num = 10
    for i in range(shard_num):
        # Re-shard the same concat dataset with a fresh sampler per shard id.
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        data3.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)
    assert verify_list == result_list
def test_Reapeat_afterPadded():
    """repeat() applied after sharding a padded concat must repeat the shard.

    8 rows (image lengths 1..8) over 2 shards -> shard 0 sees lengths
    [1, 3, 5, 7]; repeating twice must yield that list twice.
    (NOTE: 'Reapeat' in the name is a historic typo kept for compatibility.)
    """
    result_list = [1, 3, 5, 7]
    verify_list = []
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)
    repeat_num = 2
    ds3 = ds3.repeat(repeat_num)
    for item in ds3.create_dict_iterator(num_epochs=1, output_numpy=True):
        verify_list.append(len(item['image']))
    assert verify_list == result_list * repeat_num
def test_bath_afterPadded():
    """batch() applied after sharding a padded concat.

    8 one-byte rows over 2 shards -> 4 rows in shard 0; batch(2) must then
    produce exactly 2 batches.
    (NOTE: 'bath' in the name is a historic typo kept for compatibility.)
    """
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]
    data2 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)
    ds4 = ds3.batch(2)
    # Count batches, not rows.
    assert sum([1 for _ in ds4]) == 2
  100. def test_Unevenly_distributed():
  101. result_list = [[1, 4, 7], [2, 5, 8], [3, 6]]
  102. verify_list = []
  103. data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
  104. {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
  105. {'image': np.zeros(5, np.uint8)}]
  106. data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
  107. {'image': np.zeros(8, np.uint8)}]
  108. testsampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=1)
  109. ds1 = ds.PaddedDataset(data1)
  110. ds2 = ds.PaddedDataset(data2)
  111. ds3 = ds1 + ds2
  112. numShard = 3
  113. for i in range(numShard):
  114. tem_list = []
  115. testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=i, shuffle=False, num_samples=None)
  116. ds3.use_sampler(testsampler)
  117. for item in ds3.create_dict_iterator(num_epochs=1, output_numpy=True):
  118. tem_list.append(len(item['image']))
  119. verify_list.append(tem_list)
  120. assert verify_list == result_list
def test_three_datasets_connected():
    """Chained concat of three GeneratorDatasets (0..29) sharded 10 ways.

    Shard i must see exactly [i, 10 + i, 20 + i].
    """
    result_list = []
    for i in range(10):
        tem_list = []
        tem_list.append(i)
        tem_list.append(10 + i)
        tem_list.append(20 + i)
        result_list.append(tem_list)
    verify_list = []
    data1 = ds.GeneratorDataset(generator_10, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])
    data3 = ds.GeneratorDataset(generator_30, ["col1"])
    # Chained '+' builds one concat over all three sources.
    data4 = data1 + data2 + data3
    shard_num = 10
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        data4.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)
    assert verify_list == result_list
  143. def test_raise_error():
  144. data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
  145. {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
  146. {'image': np.zeros(5, np.uint8)}]
  147. data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
  148. {'image': np.zeros(8, np.uint8)}]
  149. ds1 = ds.PaddedDataset(data1)
  150. ds4 = ds1.batch(2)
  151. ds2 = ds.PaddedDataset(data2)
  152. ds3 = ds4 + ds2
  153. with pytest.raises(TypeError) as excinfo:
  154. testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
  155. ds3.use_sampler(testsampler)
  156. assert excinfo.type == 'TypeError'
  157. with pytest.raises(TypeError) as excinfo:
  158. otherSampler = ds.SequentialSampler()
  159. ds3.use_sampler(otherSampler)
  160. assert excinfo.type == 'TypeError'
  161. with pytest.raises(ValueError) as excinfo:
  162. testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=True, num_samples=None)
  163. ds3.use_sampler(testsampler)
  164. assert excinfo.type == 'ValueError'
  165. with pytest.raises(ValueError) as excinfo:
  166. testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
  167. ds3.use_sampler(testsampler)
  168. assert excinfo.type == 'ValueError'
  169. def test_imagefolder_error():
  170. DATA_DIR = "../data/dataset/testPK/data"
  171. data = ds.ImageFolderDataset(DATA_DIR, num_samples=14)
  172. data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
  173. {'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
  174. {'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
  175. {'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
  176. {'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
  177. {'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]
  178. data2 = ds.PaddedDataset(data1)
  179. data3 = data + data2
  180. with pytest.raises(ValueError) as excinfo:
  181. testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
  182. data3.use_sampler(testsampler)
  183. assert excinfo.type == 'ValueError'
def test_imagefolder_padded():
    """ImageFolder + PaddedDataset concat sharded 5 ways; check the last shard.

    The last shard (id 4) must contain 10 rows, and its final two rows must be
    padded samples (recognizable by their 'image' byte lengths 1 and 6).
    """
    DATA_DIR = "../data/dataset/testPK/data"
    data = ds.ImageFolderDataset(DATA_DIR)
    # Six synthetic rows whose 'image' length (1..6) encodes their identity.
    data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]
    data2 = ds.PaddedDataset(data1)
    data3 = data + data2
    testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
    data3.use_sampler(testsampler)
    assert sum([1 for _ in data3]) == 10
    verify_list = []
    for ele in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
        verify_list.append(len(ele['image']))
    # Rows 8 and 9 of this shard come from the padded samples.
    assert verify_list[8] == 1
    assert verify_list[9] == 6
def test_imagefolder_padded_with_decode():
    """Decode must work on every shard of an ImageFolder + padded concat.

    Pads with 4 copies of an in-memory white 224x224 JPEG (label -1) so the
    padded rows survive V_C.Decode() like real images. 44 real + 4 padded
    rows over 5 shards -> shards of 9 or 10 rows, 48 rows in total.
    """
    num_shards = 5
    count = 0
    for shard_id in range(num_shards):
        DATA_DIR = "../data/dataset/testPK/data"
        data = ds.ImageFolderDataset(DATA_DIR)
        # Encode a solid-white JPEG entirely in memory.
        white_io = BytesIO()
        Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
        padded_sample = {}
        padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
        padded_sample['label'] = np.array(-1, np.int32)  # sentinel label for padding
        white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
        data2 = ds.PaddedDataset(white_samples)
        data3 = data + data2
        testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
        data3.use_sampler(testsampler)
        # Decode is mapped after sharding; padded rows must decode cleanly too.
        data3 = data3.map(operations=V_C.Decode(), input_columns="image")
        shard_sample_count = 0
        for ele in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
            print("label: {}".format(ele['label']))
            count += 1
            shard_sample_count += 1
        assert shard_sample_count in (9, 10)
    assert count == 48
def test_imagefolder_padded_with_decode_and_get_dataset_size():
    """Same as test_imagefolder_padded_with_decode, additionally checking that
    get_dataset_size() (queried after use_sampler, before map) matches the
    number of rows actually iterated per shard.
    """
    num_shards = 5
    count = 0
    for shard_id in range(num_shards):
        DATA_DIR = "../data/dataset/testPK/data"
        data = ds.ImageFolderDataset(DATA_DIR)
        # Encode a solid-white JPEG entirely in memory.
        white_io = BytesIO()
        Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
        padded_sample = {}
        padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
        padded_sample['label'] = np.array(-1, np.int32)  # sentinel label for padding
        white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
        data2 = ds.PaddedDataset(white_samples)
        data3 = data + data2
        testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
        data3.use_sampler(testsampler)
        # Size is queried on the sharded (pre-map) dataset.
        shard_dataset_size = data3.get_dataset_size()
        data3 = data3.map(operations=V_C.Decode(), input_columns="image")
        shard_sample_count = 0
        for ele in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
            print("label: {}".format(ele['label']))
            count += 1
            shard_sample_count += 1
        assert shard_sample_count in (9, 10)
        assert shard_dataset_size == shard_sample_count
    assert count == 48
  253. def test_more_shard_padded():
  254. result_list = []
  255. for i in range(8):
  256. result_list.append(1)
  257. result_list.append(0)
  258. data1 = ds.GeneratorDataset(generator_5, ["col1"])
  259. data2 = ds.GeneratorDataset(generator_8, ["col1"])
  260. data3 = data1 + data2
  261. vertifyList = []
  262. numShard = 9
  263. for i in range(numShard):
  264. tem_list = []
  265. testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=i, shuffle=False, num_samples=None)
  266. data3.use_sampler(testsampler)
  267. for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
  268. tem_list.append(item['col1'])
  269. vertifyList.append(tem_list)
  270. assert [len(ele) for ele in vertifyList] == result_list
  271. vertifyList1 = []
  272. result_list1 = []
  273. for i in range(8):
  274. result_list1.append([i + 1])
  275. result_list1.append([])
  276. data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
  277. {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
  278. {'image': np.zeros(5, np.uint8)}]
  279. data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
  280. {'image': np.zeros(8, np.uint8)}]
  281. ds1 = ds.PaddedDataset(data1)
  282. ds2 = ds.PaddedDataset(data2)
  283. ds3 = ds1 + ds2
  284. for i in range(numShard):
  285. tem_list = []
  286. testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=i, shuffle=False, num_samples=None)
  287. ds3.use_sampler(testsampler)
  288. for item in ds3.create_dict_iterator(num_epochs=1, output_numpy=True):
  289. tem_list.append(len(item['image']))
  290. vertifyList1.append(tem_list)
  291. assert vertifyList1 == result_list1
  292. def get_data(dir_name):
  293. """
  294. usage: get data from imagenet dataset
  295. params:
  296. dir_name: directory containing folder images and annotation information
  297. """
  298. if not os.path.isdir(dir_name):
  299. raise IOError("Directory {} not exists".format(dir_name))
  300. img_dir = os.path.join(dir_name, "images")
  301. ann_file = os.path.join(dir_name, "annotation.txt")
  302. with open(ann_file, "r") as file_reader:
  303. lines = file_reader.readlines()
  304. data_list = []
  305. for i, line in enumerate(lines):
  306. try:
  307. filename, label = line.split(",")
  308. label = label.strip("\n")
  309. with open(os.path.join(img_dir, filename), "rb") as file_reader:
  310. img = file_reader.read()
  311. data_json = {"id": i,
  312. "file_name": filename,
  313. "data": img,
  314. "label": int(label)}
  315. data_list.append(data_json)
  316. except FileNotFoundError:
  317. continue
  318. return data_list
@pytest.fixture(name="remove_mindrecord_file")
def add_and_remove_cv_file():
    """add/remove cv file

    Setup/teardown fixture: clears any stale MindRecord shard files, writes
    fresh ones from CV_DIR_NAME via FileWriter, yields to the test, then
    removes the shards (and their .db index files) again. On a setup failure
    the partially written files are removed and the error re-raised.
    """
    # One path per shard file: CV_FILE_NAME + "0" .. "FILES_NUM-1".
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        # Remove leftovers from a previous run, if any.
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        # Setup failed: best-effort cleanup, then propagate the original error.
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        # Normal teardown after the test body has run.
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
  350. def test_Mindrecord_Padded(remove_mindrecord_file):
  351. result_list = []
  352. verify_list = [[1, 2], [3, 4], [5, 11], [6, 12], [7, 13], [8, 14], [9], [10]]
  353. num_readers = 4
  354. data_set = ds.MindDataset(CV_FILE_NAME + "0", ['file_name'], num_readers, shuffle=False)
  355. data1 = [{'file_name': np.array(b'image_00011.jpg', dtype='|S15')},
  356. {'file_name': np.array(b'image_00012.jpg', dtype='|S15')},
  357. {'file_name': np.array(b'image_00013.jpg', dtype='|S15')},
  358. {'file_name': np.array(b'image_00014.jpg', dtype='|S15')}]
  359. ds1 = ds.PaddedDataset(data1)
  360. ds2 = data_set + ds1
  361. shard_num = 8
  362. for i in range(shard_num):
  363. testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
  364. ds2.use_sampler(testsampler)
  365. tem_list = []
  366. for ele in ds2.create_dict_iterator(num_epochs=1, output_numpy=True):
  367. tem_list.append(int(ele['file_name'].tostring().decode().lstrip('image_').rstrip('.jpg')))
  368. result_list.append(tem_list)
  369. assert result_list == verify_list
def test_clue_padded_and_skip_with_0_samples():
    """
    Test num_samples param of CLUE dataset

    Flow: the 3-row CLUE train set is padded with one sample, shard 1 of 2 is
    taken (2 rows), then skip(2) empties it; concatenating anything onto the
    now-empty dataset must raise ValueError.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 3
    # Keep a pristine copy before 'data' gets consumed by the concat below.
    data_copy1 = copy.deepcopy(data)
    sample = {"label": np.array(1, np.string_),
              "sentence1": np.array(1, np.string_),
              "sentence2": np.array(1, np.string_)}
    samples = [sample]
    padded_ds = ds.PaddedDataset(samples)
    dataset = data + padded_ds
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
    dataset.use_sampler(testsampler)
    assert dataset.get_dataset_size() == 2
    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2
    dataset = dataset.skip(count=2)  # dataset2 has none samples
    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 0
    with pytest.raises(ValueError, match="There is no samples in the "):
        # concat on an empty dataset must raise; the lines below are
        # intentionally unreachable.
        dataset = dataset.concat(data_copy1)
        count = 0
        for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            count += 1
        assert count == 2
  405. def test_celeba_padded():
  406. data = ds.CelebADataset("../data/dataset/testCelebAData/")
  407. padded_samples = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}]
  408. padded_ds = ds.PaddedDataset(padded_samples)
  409. data = data + padded_ds
  410. dis_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
  411. data.use_sampler(dis_sampler)
  412. data = data.repeat(2)
  413. count = 0
  414. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  415. count = count + 1
  416. assert count == 4
  417. if __name__ == '__main__':
  418. test_TFRecord_Padded()
  419. test_GeneratorDataSet_Padded()
  420. test_Reapeat_afterPadded()
  421. test_bath_afterPadded()
  422. test_Unevenly_distributed()
  423. test_three_datasets_connected()
  424. test_raise_error()
  425. test_imagefolden_padded()
  426. test_more_shard_padded()
  427. test_Mindrecord_Padded(add_and_remove_cv_file)