
test_paddeddataset.py

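"""
Feature tests for ds.PaddedDataset: concatenating padded samples onto
TFRecord, Generator, ImageFolder and MindRecord datasets, sharding the
result with ds.DistributedSampler, and checking repeat/batch behaviour
plus the error cases rejected by use_sampler.
"""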
import os

import numpy as np
import pytest

import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"


def generator_5():
    for i in range(0, 5):
        yield (np.array([i]),)


def generator_8():
    for i in range(5, 8):
        yield (np.array([i]),)


def generator_10():
    for i in range(0, 10):
        yield (np.array([i]),)


def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)


def generator_30():
    for i in range(20, 30):
        yield (np.array([i]),)


def test_TFRecord_Padded():
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    # 3 real images plus 5 padded samples = 8 rows over 4 shards, so every
    # shard yields two rows; the inner lists hold the expected byte lengths.
    result_list = [[159109, 2], [192607, 3], [179251, 4], [1, 5]]
    verify_list = []
    shard_num = 4
    for i in range(shard_num):
        # shard_equal_rows=True makes every shard produce the same row count.
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"],
                                  shuffle=False, shard_equal_rows=True)
        padded_samples = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
                          {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
                          {'image': np.zeros(5, np.uint8)}]
        padded_ds = ds.PaddedDataset(padded_samples)
        concat_ds = data + padded_ds
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False,
                                            num_samples=None)
        concat_ds.use_sampler(testsampler)
        shard_list = []
        for item in concat_ds.create_dict_iterator():
            shard_list.append(len(item['image']))
        verify_list.append(shard_list)
    assert verify_list == result_list


def test_GeneratorDataSet_Padded():
    # Shard i of 10 should see exactly [i, 10 + i] after concatenation.
    result_list = [[i, 10 + i] for i in range(10)]
    verify_list = []
    data1 = ds.GeneratorDataset(generator_20, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data2 + data1
    shard_num = 10
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False,
                                                    num_samples=None)
        data3.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data3.create_dict_iterator():
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)
    assert verify_list == result_list


def test_Repeat_afterPadded():
    result_list = [1, 3, 5, 7]
    verify_list = []
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)
    repeat_num = 2
    ds3 = ds3.repeat(repeat_num)
    for item in ds3.create_dict_iterator():
        verify_list.append(len(item['image']))
    assert verify_list == result_list * repeat_num


def test_batch_afterPadded():
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]
    data2 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)
    # 8 samples over 2 shards leaves 4 per shard; batch(2) yields 2 batches.
    ds4 = ds3.batch(2)
    assert sum(1 for _ in ds4) == 2


def test_Unevenly_distributed():
    # 8 samples over 3 shards: the last shard receives one sample fewer.
    result_list = [[1, 4, 7], [2, 5, 8], [3, 6]]
    verify_list = []
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    shard_num = 3
    for i in range(shard_num):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False,
                                            num_samples=None)
        ds3.use_sampler(testsampler)
        for item in ds3.create_dict_iterator():
            tem_list.append(len(item['image']))
        verify_list.append(tem_list)
    assert verify_list == result_list


def test_three_datasets_connected():
    result_list = [[i, 10 + i, 20 + i] for i in range(10)]
    verify_list = []
    data1 = ds.GeneratorDataset(generator_10, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])
    data3 = ds.GeneratorDataset(generator_30, ["col1"])
    data4 = data1 + data2 + data3
    shard_num = 10
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False,
                                                    num_samples=None)
        data4.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data4.create_dict_iterator():
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)
    assert verify_list == result_list


def test_raise_error():
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds4 = ds1.batch(2)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds4 + ds2

    # A sampler cannot be attached once batch() has been applied upstream.
    with pytest.raises(TypeError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
    assert excinfo.type is TypeError

    # Only a DistributedSampler is accepted on this concatenated dataset.
    with pytest.raises(TypeError) as excinfo:
        otherSampler = ds.SequentialSampler()
        ds3.use_sampler(otherSampler)
    assert excinfo.type is TypeError

    # shuffle=True is rejected here.
    with pytest.raises(ValueError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=True, num_samples=None)
        ds3.use_sampler(testsampler)
    assert excinfo.type is ValueError

    # So is an explicit num_samples.
    with pytest.raises(ValueError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
        ds3.use_sampler(testsampler)
    assert excinfo.type is ValueError


def test_imagefolder_padded():
    DATA_DIR = "../data/dataset/testPK/data"
    data = ds.ImageFolderDatasetV2(DATA_DIR)
    data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]
    data2 = ds.PaddedDataset(data1)
    data3 = data + data2
    testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
    data3.use_sampler(testsampler)
    assert sum(1 for _ in data3) == 10
    verify_list = []
    for ele in data3.create_dict_iterator():
        verify_list.append(len(ele['image']))
    # The tail of the last shard comes from the padded samples, whose byte
    # lengths are 1 and 6.
    assert verify_list[8] == 1
    assert verify_list[9] == 6


def test_more_shard_padded():
    # With 9 shards for only 8 samples, the ninth shard comes back empty.
    result_list = [1] * 8 + [0]
    data1 = ds.GeneratorDataset(generator_5, ["col1"])
    data2 = ds.GeneratorDataset(generator_8, ["col1"])
    data3 = data1 + data2
    verify_list = []
    shard_num = 9
    for i in range(shard_num):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False,
                                            num_samples=None)
        data3.use_sampler(testsampler)
        for item in data3.create_dict_iterator():
            tem_list.append(item['col1'])
        verify_list.append(tem_list)
    assert [len(ele) for ele in verify_list] == result_list

    verify_list1 = []
    result_list1 = [[i + 1] for i in range(8)]
    result_list1.append([])
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    for i in range(shard_num):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False,
                                            num_samples=None)
        ds3.use_sampler(testsampler)
        for item in ds3.create_dict_iterator():
            tem_list.append(len(item['image']))
        verify_list1.append(tem_list)
    assert verify_list1 == result_list1


def get_data(dir_name):
    """
    Read image/label pairs from an ImageNet-style test directory.

    Args:
        dir_name: directory containing the image folder and annotation file.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as image_reader:
                img = image_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list
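
# get_data() expects annotation.txt to hold one "<file_name>,<label>" pair
# per line; hypothetical example contents:
#   image_00011.jpg,0
#   image_00012.jpg,1
# Typical call: data = get_data(CV_DIR_NAME)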


@pytest.fixture(name="remove_mindrecord_file")
def add_and_remove_cv_file():
    """Create the MindRecord files for the test and remove them afterwards."""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists(x):
                os.remove(x)
            if os.path.exists(x + ".db"):
                os.remove(x + ".db")
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        for x in paths:
            os.remove(x)
            os.remove(x + ".db")
        raise error
    else:
        for x in paths:
            os.remove(x)
            os.remove(x + ".db")
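
# Note: pytest exposes the fixture above to tests under the name given in
# the decorator, "remove_mindrecord_file", not under the function name
# add_and_remove_cv_file.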


def test_Mindrecord_Padded(remove_mindrecord_file):
    result_list = []
    verify_list = [[1, 2], [3, 4], [5, 11], [6, 12], [7, 13], [8, 14], [9], [10]]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", ['file_name'], num_readers, shuffle=False)
    data1 = [{'file_name': np.array(b'image_00011.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00012.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00013.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00014.jpg', dtype='|S15')}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = data_set + ds1
    shard_num = 8
    for i in range(shard_num):
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False,
                                            num_samples=None)
        ds2.use_sampler(testsampler)
        tem_list = []
        for ele in ds2.create_dict_iterator():
            # Recover the numeric id from file names like "image_00011.jpg".
            tem_list.append(int(ele['file_name'].tobytes().decode().lstrip('image_').rstrip('.jpg')))
        result_list.append(tem_list)
    assert result_list == verify_list


if __name__ == '__main__':
    test_TFRecord_Padded()
    test_GeneratorDataSet_Padded()
    test_Repeat_afterPadded()
    test_batch_afterPadded()
    test_Unevenly_distributed()
    test_three_datasets_connected()
    test_raise_error()
    test_imagefolder_padded()
    test_more_shard_padded()
    test_Mindrecord_Padded(add_and_remove_cv_file)
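
# Running this file directly bypasses the pytest fixture machinery, so
# test_Mindrecord_Padded receives the raw fixture function here rather than
# the prepared MindRecord files; run `pytest test_paddeddataset.py` to
# exercise it with the fixture applied.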