
test_paddeddataset.py

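"""Tests for mindspore.dataset.PaddedDataset: concatenating padded samples with
TFRecord, Generator, ImageFolder and MindRecord datasets, then sharding the
result with DistributedSampler."""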
from io import BytesIO
import os

import numpy as np
import pytest
from PIL import Image

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as V_C
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"

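# Helper generators yielding one-element arrays over fixed integer ranges; the
# names refer to the range endpoint, not the number of samples produced.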
def generator_5():
    for i in range(0, 5):
        yield (np.array([i]),)


def generator_8():
    for i in range(5, 8):
        yield (np.array([i]),)


def generator_10():
    for i in range(0, 10):
        yield (np.array([i]),)


def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)


def generator_30():
    for i in range(20, 30):
        yield (np.array([i]),)

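# Pad a 3-image TFRecordDataset with 5 dummy rows so the 8 samples split evenly
# across 4 shards (2 rows each); each shard is checked by the byte length of
# its rows.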
def test_TFRecord_Padded():
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    result_list = [[159109, 2], [192607, 3], [179251, 4], [1, 5]]
    verify_list = []
    shard_num = 4
    for i in range(shard_num):
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"],
                                  shuffle=False, shard_equal_rows=True)
        padded_samples = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
                          {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
                          {'image': np.zeros(5, np.uint8)}]
        padded_ds = ds.PaddedDataset(padded_samples)
        concat_ds = data + padded_ds
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        concat_ds.use_sampler(testsampler)
        shard_list = []
        for item in concat_ds.create_dict_iterator():
            shard_list.append(len(item['image']))
        verify_list.append(shard_list)
    assert verify_list == result_list

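# Concatenate two GeneratorDatasets (values 0..9 and 10..19) and shard into 10
# parts; shard i should see exactly the values [i, 10 + i].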
def test_GeneratorDataSet_Padded():
    result_list = []
    for i in range(10):
        tem_list = []
        tem_list.append(i)
        tem_list.append(10 + i)
        result_list.append(tem_list)

    verify_list = []
    data1 = ds.GeneratorDataset(generator_20, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data2 + data1
    shard_num = 10
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False,
                                                    num_samples=None)
        data3.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data3.create_dict_iterator():
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)
    assert verify_list == result_list

def test_Repeat_afterPadded():
    result_list = [1, 3, 5, 7]
    verify_list = []
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)
    repeat_num = 2
    ds3 = ds3.repeat(repeat_num)
    for item in ds3.create_dict_iterator():
        verify_list.append(len(item['image']))
    assert verify_list == result_list * repeat_num

def test_batch_afterPadded():
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]
    data2 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)

    ds4 = ds3.batch(2)
    assert sum([1 for _ in ds4]) == 2

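# 8 samples cannot divide evenly into 3 shards: the first two shards should get
# 3 rows each and the last shard 2.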
def test_Unevenly_distributed():
    result_list = [[1, 4, 7], [2, 5, 8], [3, 6]]
    verify_list = []
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]

    # Unused; overwritten by the samplers created in the loop below.
    testsampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=1)

    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    numShard = 3
    for i in range(numShard):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=i, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
        for item in ds3.create_dict_iterator():
            tem_list.append(len(item['image']))
        verify_list.append(tem_list)
    assert verify_list == result_list

def test_three_datasets_connected():
    result_list = []
    for i in range(10):
        tem_list = []
        tem_list.append(i)
        tem_list.append(10 + i)
        tem_list.append(20 + i)
        result_list.append(tem_list)

    verify_list = []
    data1 = ds.GeneratorDataset(generator_10, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])
    data3 = ds.GeneratorDataset(generator_30, ["col1"])
    data4 = data1 + data2 + data3
    shard_num = 10
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False,
                                                    num_samples=None)
        data4.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data4.create_dict_iterator():
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)
    assert verify_list == result_list

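# A concatenated dataset built on a batched input only rejects sampling, and in
# general accepts only an unshuffled DistributedSampler with num_samples=None;
# every other sampler or setting tried here should raise.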
def test_raise_error():
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds4 = ds1.batch(2)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds4 + ds2

    with pytest.raises(TypeError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
    assert excinfo.type is TypeError

    with pytest.raises(TypeError) as excinfo:
        otherSampler = ds.SequentialSampler()
        ds3.use_sampler(otherSampler)
    assert excinfo.type is TypeError

    with pytest.raises(ValueError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=True, num_samples=None)
        ds3.use_sampler(testsampler)
    assert excinfo.type is ValueError

    with pytest.raises(ValueError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
        ds3.use_sampler(testsampler)
    assert excinfo.type is ValueError

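# The ImageFolder test data (44 images) plus 6 padded rows split evenly into
# 5 shards of 10 samples; the last shard's final two rows are the padded
# samples of lengths 1 and 6.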
def test_imagefolder_padded():
    DATA_DIR = "../data/dataset/testPK/data"
    data = ds.ImageFolderDatasetV2(DATA_DIR)

    data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]

    data2 = ds.PaddedDataset(data1)
    data3 = data + data2
    testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
    data3.use_sampler(testsampler)
    assert sum([1 for _ in data3]) == 10
    verify_list = []
    for ele in data3.create_dict_iterator():
        verify_list.append(len(ele['image']))
    assert verify_list[8] == 1
    assert verify_list[9] == 6

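# Pad with real JPEG bytes (a white 224x224 image) so that map(Decode()) also
# succeeds on the padded rows; 44 real + 4 padded = 48 samples over 5 shards,
# so each shard sees 9 or 10 rows.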
def test_imagefolder_padded_with_decode():
    num_shards = 5
    count = 0
    for shard_id in range(num_shards):
        DATA_DIR = "../data/dataset/testPK/data"
        data = ds.ImageFolderDatasetV2(DATA_DIR)

        white_io = BytesIO()
        Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
        padded_sample = {}
        padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
        padded_sample['label'] = np.array(-1, np.int32)

        white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
        data2 = ds.PaddedDataset(white_samples)
        data3 = data + data2

        testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False,
                                            num_samples=None)
        data3.use_sampler(testsampler)
        data3 = data3.map(input_columns="image", operations=V_C.Decode())
        shard_sample_count = 0
        for ele in data3.create_dict_iterator():
            print("label: {}".format(ele['label']))
            count += 1
            shard_sample_count += 1
        assert shard_sample_count in (9, 10)
    assert count == 48

def test_imagefolder_padded_with_decode_and_get_dataset_size():
    num_shards = 5
    count = 0
    for shard_id in range(num_shards):
        DATA_DIR = "../data/dataset/testPK/data"
        data = ds.ImageFolderDatasetV2(DATA_DIR)

        white_io = BytesIO()
        Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
        padded_sample = {}
        padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
        padded_sample['label'] = np.array(-1, np.int32)

        white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
        data2 = ds.PaddedDataset(white_samples)
        data3 = data + data2

        testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False,
                                            num_samples=None)
        data3.use_sampler(testsampler)
        shard_dataset_size = data3.get_dataset_size()
        data3 = data3.map(input_columns="image", operations=V_C.Decode())
        shard_sample_count = 0
        for ele in data3.create_dict_iterator():
            print("label: {}".format(ele['label']))
            count += 1
            shard_sample_count += 1
        assert shard_sample_count in (9, 10)
        assert shard_dataset_size == shard_sample_count
    assert count == 48

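# With more shards (9) than samples (8), the trailing shard receives zero
# samples; verified for both GeneratorDataset and PaddedDataset inputs.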
def test_more_shard_padded():
    result_list = []
    for i in range(8):
        result_list.append(1)
    result_list.append(0)

    data1 = ds.GeneratorDataset(generator_5, ["col1"])
    data2 = ds.GeneratorDataset(generator_8, ["col1"])
    data3 = data1 + data2
    verify_list = []
    numShard = 9
    for i in range(numShard):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=i, shuffle=False, num_samples=None)
        data3.use_sampler(testsampler)
        for item in data3.create_dict_iterator():
            tem_list.append(item['col1'])
        verify_list.append(tem_list)
    assert [len(ele) for ele in verify_list] == result_list

    verify_list1 = []
    result_list1 = []
    for i in range(8):
        result_list1.append([i + 1])
    result_list1.append([])

    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    for i in range(numShard):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=i, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
        for item in ds3.create_dict_iterator():
            tem_list.append(len(item['image']))
        verify_list1.append(tem_list)
    assert verify_list1 == result_list1

def get_data(dir_name):
    """
    Read an imagenet-style dataset from disk.

    Args:
        dir_name: directory containing the images folder and annotation file.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list

@pytest.fixture(name="remove_mindrecord_file")
def add_and_remove_cv_file():
    """Create mindrecord cv files for the test and remove them afterwards."""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))

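# 10 mindrecord rows plus 4 padded file names give 14 samples across 8 shards;
# each shard's content is checked by the numeric part of the file name.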
def test_Mindrecord_Padded(remove_mindrecord_file):
    result_list = []
    verify_list = [[1, 2], [3, 4], [5, 11], [6, 12], [7, 13], [8, 14], [9], [10]]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", ['file_name'], num_readers, shuffle=False)
    data1 = [{'file_name': np.array(b'image_00011.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00012.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00013.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00014.jpg', dtype='|S15')}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = data_set + ds1
    shard_num = 8
    for i in range(shard_num):
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        ds2.use_sampler(testsampler)
        tem_list = []
        for ele in ds2.create_dict_iterator():
            tem_list.append(int(ele['file_name'].tostring().decode().lstrip('image_').rstrip('.jpg')))
        result_list.append(tem_list)
    assert result_list == verify_list

if __name__ == '__main__':
    test_TFRecord_Padded()
    test_GeneratorDataSet_Padded()
    test_Repeat_afterPadded()
    test_batch_afterPadded()
    test_Unevenly_distributed()
    test_three_datasets_connected()
    test_raise_error()
    test_imagefolder_padded()
    test_more_shard_padded()
    test_Mindrecord_Padded(add_and_remove_cv_file)