
test_paddeddataset.py 15 kB

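"""Tests for mindspore.dataset PaddedDataset: padding extra samples onto a
dataset, concatenating datasets, and sharding the result with
DistributedSampler."""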
from io import BytesIO
import os

import numpy as np
import pytest
from PIL import Image

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as V_C
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
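
# Helper generators yielding single-column samples over fixed, disjoint
# integer ranges, so each shard's contents can be asserted exactly below.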
def generator_5():
    for i in range(0, 5):
        yield (np.array([i]),)

def generator_8():
    for i in range(5, 8):
        yield (np.array([i]),)

def generator_10():
    for i in range(0, 10):
        yield (np.array([i]),)

def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)

def generator_30():
    for i in range(20, 30):
        yield (np.array([i]),)

def test_TFRecord_Padded():
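    """Concatenate a 3-image TFRecordDataset with a 5-sample PaddedDataset
    and shard the result 4 ways; each shard must see the expected image sizes."""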
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    result_list = [[159109, 2], [192607, 3], [179251, 4], [1, 5]]
    verify_list = []
    shard_num = 4
    for i in range(shard_num):
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"],
                                  shuffle=False, shard_equal_rows=True)
        padded_samples = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
                          {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
                          {'image': np.zeros(5, np.uint8)}]
        padded_ds = ds.PaddedDataset(padded_samples)
        concat_ds = data + padded_ds
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        concat_ds.use_sampler(testsampler)
        shard_list = []
        for item in concat_ds.create_dict_iterator():
            shard_list.append(len(item['image']))
        verify_list.append(shard_list)
    assert verify_list == result_list

def test_GeneratorDataSet_Padded():
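    """Concatenate two GeneratorDatasets covering 0-9 and 10-19 and shard the
    result 10 ways; shard i must yield [i, 10 + i]."""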
    result_list = []
    for i in range(10):
        tem_list = []
        tem_list.append(i)
        tem_list.append(10 + i)
        result_list.append(tem_list)
    verify_list = []
    data1 = ds.GeneratorDataset(generator_20, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data2 + data1
    shard_num = 10
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        data3.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data3.create_dict_iterator():
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)
    assert verify_list == result_list

def test_Repeat_afterPadded():
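    """Apply repeat() after use_sampler(); shard 0 of a two-way split over 8
    samples (sizes 1-8) must yield sizes [1, 3, 5, 7] once per repeat."""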
    result_list = [1, 3, 5, 7]
    verify_list = []
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)
    repeat_num = 2
    ds3 = ds3.repeat(repeat_num)
    for item in ds3.create_dict_iterator():
        verify_list.append(len(item['image']))
    assert verify_list == result_list * repeat_num

def test_batch_afterPadded():
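    """Apply batch() after use_sampler(); 8 samples split 2 ways give 4 per
    shard, so batch(2) must yield exactly 2 batches."""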
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]
    data2 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)
    ds4 = ds3.batch(2)
    assert sum([1 for _ in ds4]) == 2

def test_Unevenly_distributed():
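    """Shard 8 samples (sizes 1-8) over 3 shards; the split is uneven, so the
    shards see sizes [1, 4, 7], [2, 5, 8], and [3, 6]."""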
    result_list = [[1, 4, 7], [2, 5, 8], [3, 6]]
    verify_list = []
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]
    testsampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=1)
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    numShard = 3
    for i in range(numShard):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=i, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
        for item in ds3.create_dict_iterator():
            tem_list.append(len(item['image']))
        verify_list.append(tem_list)
    assert verify_list == result_list

def test_three_datasets_connected():
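    """Chain three GeneratorDatasets (0-9, 10-19, 20-29) and shard the result
    10 ways; shard i must yield [i, 10 + i, 20 + i]."""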
    result_list = []
    for i in range(10):
        tem_list = []
        tem_list.append(i)
        tem_list.append(10 + i)
        tem_list.append(20 + i)
        result_list.append(tem_list)
    verify_list = []
    data1 = ds.GeneratorDataset(generator_10, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])
    data3 = ds.GeneratorDataset(generator_30, ["col1"])
    data4 = data1 + data2 + data3
    shard_num = 10
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        data4.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data4.create_dict_iterator():
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)
    assert verify_list == result_list

def test_raise_error():
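    """use_sampler() on a concatenated dataset must reject a batched child and
    a non-distributed sampler (TypeError), and a DistributedSampler with
    shuffle=True or an explicit num_samples (ValueError)."""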
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds4 = ds1.batch(2)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds4 + ds2
    with pytest.raises(TypeError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
    assert excinfo.type is TypeError
    with pytest.raises(TypeError) as excinfo:
        otherSampler = ds.SequentialSampler()
        ds3.use_sampler(otherSampler)
    assert excinfo.type is TypeError
    with pytest.raises(ValueError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=True, num_samples=None)
        ds3.use_sampler(testsampler)
    assert excinfo.type is ValueError
    with pytest.raises(ValueError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
        ds3.use_sampler(testsampler)
    assert excinfo.type is ValueError

def test_imagefolder_padded():
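    """Pad an ImageFolderDatasetV2 with 6 labelled samples and shard 5 ways;
    the last shard holds 10 items and ends with padded samples of size 1 and 6."""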
    DATA_DIR = "../data/dataset/testPK/data"
    data = ds.ImageFolderDatasetV2(DATA_DIR)
    data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]
    data2 = ds.PaddedDataset(data1)
    data3 = data + data2
    testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
    data3.use_sampler(testsampler)
    assert sum([1 for _ in data3]) == 10
    verify_list = []
    for ele in data3.create_dict_iterator():
        verify_list.append(len(ele['image']))
    assert verify_list[8] == 1
    assert verify_list[9] == 6

def test_imagefolder_padded_with_decode():
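    """Pad an ImageFolderDatasetV2 with valid white JPEG samples so that
    Decode() can be mapped over every shard; the 5 shards together must
    yield 48 samples."""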
    DATA_DIR = "../data/dataset/testPK/data"
    data = ds.ImageFolderDatasetV2(DATA_DIR)
    white_io = BytesIO()
    Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
    padded_sample = {}
    padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
    padded_sample['label'] = np.array(-1, np.int32)
    white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
    data2 = ds.PaddedDataset(white_samples)
    data3 = data + data2
    num_shards = 5
    count = 0
    for shard_id in range(num_shards):
        testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
        data3.use_sampler(testsampler)
        data4 = data3.map(input_columns="image", operations=V_C.Decode())
        for ele in data4.create_dict_iterator():
            print("label: {}".format(ele['label']))
            count += 1
    assert count == 48

def test_more_shard_padded():
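    """Use more shards (9) than there are samples (8): each of the first 8
    shards gets exactly one sample and the extra shard comes back empty, for
    both generator-backed and PaddedDataset-backed concatenations."""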
    result_list = []
    for i in range(8):
        result_list.append(1)
    result_list.append(0)
    data1 = ds.GeneratorDataset(generator_5, ["col1"])
    data2 = ds.GeneratorDataset(generator_8, ["col1"])
    data3 = data1 + data2
    verify_list = []
    numShard = 9
    for i in range(numShard):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=i, shuffle=False, num_samples=None)
        data3.use_sampler(testsampler)
        for item in data3.create_dict_iterator():
            tem_list.append(item['col1'])
        verify_list.append(tem_list)
    assert [len(ele) for ele in verify_list] == result_list

    verify_list1 = []
    result_list1 = []
    for i in range(8):
        result_list1.append([i + 1])
    result_list1.append([])
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    for i in range(numShard):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=i, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
        for item in ds3.create_dict_iterator():
            tem_list.append(len(item['image']))
        verify_list1.append(tem_list)
    assert verify_list1 == result_list1

def get_data(dir_name):
    """
    Read image/label pairs from an ImageNet-style test directory.

    Args:
        dir_name: directory containing an "images" folder and an
            "annotation.txt" file mapping file names to integer labels.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()
    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list

@pytest.fixture(name="remove_mindrecord_file")
def add_and_remove_cv_file():
    """add/remove cv file"""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))

def test_Mindrecord_Padded(remove_mindrecord_file):
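    """Pad a 10-record MindDataset with 4 extra file names and shard 8 ways;
    the image numbers seen per shard must match verify_list."""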
    result_list = []
    verify_list = [[1, 2], [3, 4], [5, 11], [6, 12], [7, 13], [8, 14], [9], [10]]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", ['file_name'], num_readers, shuffle=False)
    data1 = [{'file_name': np.array(b'image_00011.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00012.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00013.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00014.jpg', dtype='|S15')}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = data_set + ds1
    shard_num = 8
    for i in range(shard_num):
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        ds2.use_sampler(testsampler)
        tem_list = []
        for ele in ds2.create_dict_iterator():
            tem_list.append(int(ele['file_name'].tostring().decode().lstrip('image_').rstrip('.jpg')))
        result_list.append(tem_list)
    assert result_list == verify_list

if __name__ == '__main__':
    test_TFRecord_Padded()
    test_GeneratorDataSet_Padded()
    test_Repeat_afterPadded()
    test_batch_afterPadded()
    test_Unevenly_distributed()
    test_three_datasets_connected()
    test_raise_error()
    test_imagefolder_padded()
    test_more_shard_padded()
    test_Mindrecord_Padded(add_and_remove_cv_file)